torch_util.py source code

Project: multiNLI_encoder, Author: easonnie
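```python
import torch


def pack_to_matching_matrix(s1, s2, cat_only=(False, False)):
    """Pair every time step of s1 with every time step of s2.

    s1: (t1, batch_size, d), s2: (t2, batch_size, d).
    Returns a tensor of shape (t2, t1, batch_size, k * d), where
    k = 2 if both flags are False, k = 3 if only cat_only[1] is True,
    and k = 4 whenever cat_only[0] is True.
    """
    t1 = s1.size(0)
    t2 = s2.size(0)
    batch_size = s1.size(1)
    d = s1.size(2)

    # Broadcast s1 along a new leading axis: (t2, t1, batch_size, d).
    expanded_p_s1 = s1.expand(t2, t1, batch_size, d)

    # Add a singleton t1 axis to s2, then broadcast: (t2, t1, batch_size, d).
    # unsqueeze(1) matches view(t2, 1, batch_size, d) but also accepts
    # non-contiguous inputs.
    expanded_p_s2 = s2.unsqueeze(1).expand(t2, t1, batch_size, d)

    if not cat_only[0] and not cat_only[1]:
        # [s1; s2]
        matrix = torch.cat((expanded_p_s1, expanded_p_s2), dim=3)
    elif not cat_only[0] and cat_only[1]:
        # [s1; s2; s1 * s2]
        matrix = torch.cat((expanded_p_s1, expanded_p_s2,
                            expanded_p_s1 * expanded_p_s2), dim=3)
    else:
        # cat_only[0] is True: [s1; s2; |s1 - s2|; s1 * s2]
        matrix = torch.cat((expanded_p_s1,
                            expanded_p_s2,
                            torch.abs(expanded_p_s1 - expanded_p_s2),
                            expanded_p_s1 * expanded_p_s2), dim=3)

    return matrix
```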
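The function relies on two broadcasting steps, and their index order is easy to misread. A minimal sketch, assuming PyTorch is installed and using arbitrary small sizes chosen only for illustration, checks what each cell of the expanded tensors holds. It inserts the singleton t1 axis with `unsqueeze(1)`, which gives the same result as a `view(t2, 1, batch_size, d)` reshape on a contiguous tensor.

```python
import torch

# Arbitrary illustrative sizes for this sketch.
t1, t2, batch_size, d = 3, 2, 4, 5
s1 = torch.randn(t1, batch_size, d)
s2 = torch.randn(t2, batch_size, d)

# The two broadcasting steps behind the matching matrix.
expanded_p_s1 = s1.expand(t2, t1, batch_size, d)
expanded_p_s2 = s2.unsqueeze(1).expand(t2, t1, batch_size, d)

# expand() returns a view of the same storage; no data is copied yet.
assert expanded_p_s1.data_ptr() == s1.data_ptr()

# Cell (i, j) pairs time step i of s2 with time step j of s1.
for i in range(t2):
    for j in range(t1):
        assert torch.equal(expanded_p_s1[i, j], s1[j])
        assert torch.equal(expanded_p_s2[i, j], s2[i])

# torch.cat materializes the pairs along the feature axis.
pairs = torch.cat((expanded_p_s1, expanded_p_s2), dim=3)
print(tuple(pairs.shape))  # (2, 3, 4, 10)
```

Because the first axis indexes s2 and the second indexes s1, the result is laid out as (t2, t1, batch_size, ...), not (t1, t2, ...). Note also that only the final `torch.cat` allocates memory, and it allocates t2 * t1 * batch_size * k * d elements, so long sequences can make this matrix large.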