def forward(self, cont_repres, other_cont_repres):
    """Compare two sequences pairwise under multiple perspectives.

    Both inputs are expanded with ``self.weight`` through
    ``multi_perspective_expand_for_2D``; every position of
    ``cont_repres`` is then scored against every position of
    ``other_cont_repres`` with ``cosine_similarity``, and the scores
    are pooled (max and mean) over the other sequence.

    Args:
        cont_repres: [batch_size, this_len, context_lstm_dim]
        other_cont_repres: [batch_size, other_len, context_lstm_dim]

    Returns:
        Tensor of size [bsz, this_len, mp_dim*2]; the max-pooled scores
        come first along the last axis, followed by the mean-pooled ones.
    """
    batch = cont_repres.size(0)
    this_len = cont_repres.size(1)
    other_len = other_cont_repres.size(1)

    def _expand(seq, seq_len):
        # The expansion helper works on 2D input, so collapse the batch
        # and time axes first and restore them afterwards:
        # result is [bsz, seq_len, mp_dim, cont_dim].
        flat = seq.view(-1, self.cont_dim)
        expanded = multi_perspective_expand_for_2D(flat, self.weight)
        return expanded.view(batch, seq_len, self.mp_dim, self.cont_dim)

    # Singleton axes are inserted so the two sides line up as a
    # this_len x other_len grid.
    # [bsz, this_len, 1, mp_dim, cont_dim]
    this_side = _expand(cont_repres, this_len).unsqueeze(2)
    # [bsz, 1, other_len, mp_dim, cont_dim]
    other_side = _expand(other_cont_repres, other_len).unsqueeze(1)

    # Similarity is taken over the last (cont_dim) axis:
    # [bsz, this_len, other_len, mp_dim]
    last_axis = this_side.dim() - 1
    simi = cosine_similarity(this_side, other_side, last_axis)

    # Pool over the other sequence's positions (axis 2).
    pooled_max = simi.max(2)[0]
    pooled_mean = simi.mean(2)
    return torch.cat((pooled_max, pooled_mean), 2)
# (Web page residue, not part of the code: "Comment list" / "Article table of contents")