def batch_cosine(self, doc_batch_proj, query_batch_proj, eps=1e-12):
    """Return the row-wise cosine similarity between two batches of vectors.

    Entry ``i`` of the result is
    ``dot(doc[i], query[i]) / (||doc[i]|| * ||query[i]||)``.

    Args:
        doc_batch_proj: Symbolic tensor of document projections. Assumed to
            be a matrix of shape (batch, dim), since the norms reduce over
            axis 1 -- TODO confirm against the caller.
        query_batch_proj: Symbolic tensor of query projections, assumed to
            have the same shape as ``doc_batch_proj``.
        eps: Lower bound applied to each squared norm before taking its
            square root. Rows whose squared norm is at least ``eps`` get
            exactly the plain cosine. An all-zero row yields a similarity of
            0 instead of NaN (0/0), and its gradient stays finite.

    Returns:
        Symbolic tensor holding one cosine similarity per row (a vector for
        matrix inputs).
    """
    # batched_dot pairs row i of doc with row i of query.
    dot_prod = T.batched_dot(doc_batch_proj, query_batch_proj)

    doc_sq_norm = T.sum(T.sqr(doc_batch_proj), axis=1)
    query_sq_norm = T.sum(T.sqr(query_batch_proj), axis=1)

    # Clamp the squared norms away from zero. Otherwise a zero vector gives
    # 0/0 in the forward pass and differentiates sqrt at 0 (infinite slope)
    # in the backward pass, producing NaNs.
    doc_norm = T.sqrt(T.maximum(doc_sq_norm, eps))
    query_norm = T.sqrt(T.maximum(query_sq_norm, eps))

    return dot_prod / (doc_norm * query_norm)
评论列表
文章目录