def forward(self, cont_repres, other_cont_first):
    """Compare every position of a sequence with one vector, per perspective.

    Both inputs are scaled element-wise by the rows of ``self.weight`` and the
    scaled tensors are passed to ``cosine_similarity`` along the last axis.

    Args:
        cont_repres: [batch_size, this_len, context_lstm_dim]
        other_cont_first: [batch_size, context_lstm_dim]

    Returns:
        [batch_size, this_len, mp_dim], assuming ``cosine_similarity``
        (defined elsewhere) reduces the axis it is given.
    """
    # self.weight is expected to be [mp_dim, context_lstm_dim]; add two
    # leading axes -> [1, 1, mp_dim, context_lstm_dim].
    perspectives = self.weight.unsqueeze(0).unsqueeze(0)

    # [batch_size, this_len, 1, context_lstm_dim] broadcast against the
    # weights -> [batch_size, this_len, mp_dim, context_lstm_dim].
    weighted_seq = torch.mul(cont_repres.unsqueeze(2), perspectives)

    # The 2D helper lives outside this block; after the extra axis its result
    # is expected to be [batch_size, 1, mp_dim, context_lstm_dim] so that it
    # broadcasts over this_len.
    weighted_vec = multi_perspective_expand_for_2D(other_cont_first, self.weight)
    weighted_vec = weighted_vec.unsqueeze(1)

    last_axis = weighted_seq.dim() - 1
    return cosine_similarity(weighted_seq, weighted_vec, last_axis)
评论列表
文章目录