module_utils.py source code


Project: torch_light  Author: ne7ermore
import torch


def cosine_cont(repr_context, relevancy, norm=False):
    """
    Relevancy-weighted sum of context vectors: each position of this sequence
    aggregates the other sequence's context vectors, weighted by its relevancy
    (e.g. cosine-similarity) scores. With norm=True the sum is divided by the
    total weight per position, giving a weighted mean.
    Args:
        repr_context - [batch_size, other_len, context_lstm_dim]
        relevancy - [batch_size, this_len, other_len]
        norm - if True, divide by the per-position sum of relevancy (clamped to >= 1e-6)
    Return:
        tensor of shape [batch_size, this_len, context_lstm_dim]
    """
    dim = repr_context.dim() # 3 for a [batch_size, other_len, context_lstm_dim] tensor

    temp_relevancy = relevancy.unsqueeze(dim) # [batch_size, this_len, other_len, 1]
    buff = repr_context.unsqueeze(1) # [batch_size, 1, other_len, context_lstm_dim]
    buff = torch.mul(buff, temp_relevancy) # [batch_size, this_len, other_len, context_lstm_dim]
    buff = buff.sum(2) # [batch_size, this_len, context_lstm_dim]
    if norm:
        relevancy = relevancy.sum(dim-1).clamp(min=1e-6) # [batch_size, this_len]
        relevancy = relevancy.unsqueeze(2) # [batch_size, this_len, 1]
        buff = buff.div(relevancy)
    return buff
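The broadcast, multiply and sum steps compute out[b, i, :] = sum over j of relevancy[b, i, j] * repr_context[b, j, :], which is a batched matrix product. It should therefore be possible to get the same result with torch.bmm without allocating the intermediate [batch_size, this_len, other_len, context_lstm_dim] tensor. The following is a minimal sketch and not code from torch_light. cosine_cont_bmm and cosine_cont_broadcast are hypothetical names; the latter only restates cosine_cont so that the block runs on its own. The random inputs are illustrative.

```python
import torch


def cosine_cont_broadcast(repr_context, relevancy, norm=False):
    # Restatement of cosine_cont: broadcast to
    # [batch_size, this_len, other_len, context_lstm_dim], then sum over other_len.
    buff = (repr_context.unsqueeze(1) * relevancy.unsqueeze(3)).sum(2)
    if norm:
        buff = buff / relevancy.sum(2).clamp(min=1e-6).unsqueeze(2)
    return buff


def cosine_cont_bmm(repr_context, relevancy, norm=False):
    # Hypothetical batched-matmul variant:
    # [batch_size, this_len, other_len] @ [batch_size, other_len, context_lstm_dim]
    # -> [batch_size, this_len, context_lstm_dim], with no 4-D intermediate.
    buff = torch.bmm(relevancy, repr_context)
    if norm:
        # Per-position total weight, floored at 1e-6 as in the original.
        weight = relevancy.sum(dim=2, keepdim=True).clamp(min=1e-6)  # [batch_size, this_len, 1]
        buff = buff / weight
    return buff


torch.manual_seed(0)
ctx = torch.randn(2, 5, 4)  # [batch_size=2, other_len=5, context_lstm_dim=4]
rel = torch.rand(2, 3, 5)   # [batch_size=2, this_len=3, other_len=5], non-negative weights
out = cosine_cont_bmm(ctx, rel, norm=True)
print(out.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(out, cosine_cont_broadcast(ctx, rel, norm=True), atol=1e-5))
```

If relevancy is all ones and norm=True, both versions reduce to the plain mean of the context vectors over other_len, which makes a convenient sanity check. Cosine similarities can be negative. When a position's weight sum falls below 1e-6, the clamp replaces it with 1e-6, so norm=True yields a true weighted mean only for non-negative relevancy.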