def batch_cosine(self, doc_batch_proj, query_batch_proj):
    """Return the row-wise cosine similarity between two batches of vectors.

    Entry ``i`` of the result is
    ``dot(doc[i], query[i]) / (||doc[i]|| * ||query[i]||)``.

    Args:
        doc_batch_proj: symbolic Theano tensor of document projections.
            Presumably shaped (batch, dim), since the norms reduce over
            axis 1 -- TODO confirm against callers.
        query_batch_proj: symbolic Theano tensor of query projections,
            assumed to have the same shape as ``doc_batch_proj``.

    Returns:
        Symbolic tensor with one cosine similarity per batch row. A row
        whose document or query vector is all zeros divides by zero and
        yields NaN/inf.

    ``self`` is not used by this computation.
    """
    # Per-row dot products: doc[i] . query[i] for each row i.
    dot_prod = T.batched_dot(doc_batch_proj, query_batch_proj)
    # Per-row L2 norms. These must stay vectors (no further .sum() over
    # the batch) so each row is normalised by its own norms rather than
    # by a single batch-wide scalar.
    doc_norm = T.sqrt((doc_batch_proj ** 2).sum(axis=1))
    query_norm = T.sqrt((query_batch_proj ** 2).sum(axis=1))
    batch_cosine_vec = dot_prod / (doc_norm * query_norm)
    return batch_cosine_vec
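# Web-page residue captured with this snippet ("评论列表" = "Comment list",
# "文章目录" = "Table of contents"); not part of the code.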