def forward(self, repres, max_att):
    """Score ``repres`` against ``max_att`` with multi-perspective cosine matching.

    Both inputs are flattened to rows of width ``self.cont_dim``, expanded
    with ``self.weight`` through ``multi_perspective_expand_for_2D``, and
    compared with ``cosine_similarity`` along the last axis of the expanded
    tensors.

    Args:
        repres: Tensor documented as [bsz, a_len|q_len, cont_dim]; its first
            two sizes supply the batch size and sentence length of the output.
        max_att: Tensor documented as [bsz, q_len|a_len, cont_dim].
            NOTE(review): the final reshape to (bsz, sent_len, mp_dim)
            suggests it must flatten to the same number of rows as
            ``repres`` -- confirm against the caller.

    Returns:
        The similarity scores viewed as [bsz, sentence_len, mp_dim].
    """
    batch_size, seq_len = repres.size(0), repres.size(1)

    # Collapse the batch and sequence axes so each row is one token vector.
    flat_repres = repres.view(-1, self.cont_dim)
    flat_att = max_att.view(-1, self.cont_dim)

    # Both sides are expanded with the same perspective weights.
    expanded_repres = multi_perspective_expand_for_2D(flat_repres, self.weight)
    expanded_att = multi_perspective_expand_for_2D(flat_att, self.weight)

    # Compare along the last axis of the expanded representation.
    last_axis = expanded_repres.dim() - 1
    similarity = cosine_similarity(expanded_repres, expanded_att, last_axis)

    # Restore the batch and sequence axes dropped by the flattening above.
    return similarity.view(batch_size, seq_len, self.mp_dim)
评论列表
文章目录