def _cosine_similarity_mask(self, simi, clamp_min=None):
    """Apply the question and answer masks to a similarity tensor.

    ``simi`` is multiplied by ``self.q_mask`` (broadcast over the answer
    axis) and clamped from below, then multiplied by ``self.a_mask``
    (broadcast over the question axis) and clamped from below again.

    Args:
        simi: similarity tensor of shape ``[bsz, a_len, q_len]``.
        clamp_min: lower bound used for ``clamp(min=...)`` after each
            masking step. When omitted, the global ``eps`` is used, which
            preserves the original behaviour for existing callers.

    Returns:
        A new tensor with the same shape as ``simi``. Every entry is at
        least the lower bound, so positions zeroed by a mask hold that
        bound rather than 0. The input tensor is not modified.
    """
    # Look up the global ``eps`` only when no explicit bound is given, so
    # callers that pass ``clamp_min`` do not depend on that global existing.
    floor = eps if clamp_min is None else clamp_min
    # NOTE(review): the unsqueeze axes imply q_mask is [bsz, q_len] and
    # a_mask is [bsz, a_len] — confirm against the code that sets them.
    # NOTE(review): clamp(min=floor) also lifts negative similarities up to
    # ``floor``, not only the masked-out zeros — confirm this is intended.
    simi = torch.mul(simi, self.q_mask.unsqueeze(1).float()).clamp(min=floor)
    simi = torch.mul(simi, self.a_mask.unsqueeze(2).float()).clamp(min=floor)
    return simi
评论列表
文章目录