def cosine_cont(repr_context, relevancy, norm=False):
    """
    Aggregate context vectors weighted by relevancy scores.

    For every position in the "this" sequence, compute a weighted sum of
    the "other" sequence's context vectors, using `relevancy` as the
    weights. With ``norm=True`` the sum is divided by the total relevancy
    mass, turning it into a weighted average.

    Args:
        repr_context: Tensor [batch_size, other_len, context_lstm_dim].
        relevancy: Tensor [batch_size, this_len, other_len] of weights
            (e.g. cosine similarities between the two sequences).
        norm: If True, normalize each weighted sum by the summed
            relevancy, clamped at 1e-6 to avoid division by zero.

    Returns:
        Tensor [batch_size, this_len, context_lstm_dim].
    """
    # Both unsqueezes below assume the documented 3-D shapes; axes are
    # written as literals for clarity.
    weights = relevancy.unsqueeze(3)        # [batch, this_len, other_len, 1]
    context = repr_context.unsqueeze(1)     # [batch, 1, other_len, dim]
    # Broadcast-multiply, then reduce over other_len.
    weighted = (context * weights).sum(2)   # [batch, this_len, dim]
    if norm:
        # Total relevancy mass per (batch, this) position; clamping keeps
        # the division finite when all weights are zero.
        denom = relevancy.sum(2).clamp(min=1e-6).unsqueeze(2)  # [batch, this_len, 1]
        weighted = weighted.div(denom)
    return weighted
# NOTE(review): the two lines below were scraped web-page residue
# ("评论列表" = comment list, "文章目录" = article table of contents),
# not Python — left here as a comment so the module parses.
# 评论列表
# 文章目录