DCN.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:DCN 作者: alexnowakvila 项目源码 文件源码
def sort_by_embeddings(self, Phis, Inputs_N, e):
        ind = torch.sort(e, 1)[1].squeeze()
        for i, phis in enumerate(Phis):
            # rearange phis
            phis_out = (torch.gather(Phis[i], 1, ind.unsqueeze(2)
                        .expand_as(phis)))
            Phis[i] = (torch.gather(phis_out, 2, ind.unsqueeze(1)
                       .expand_as(phis)))
            # rearange inputs
            Inputs_N[i] = torch.gather(Inputs_N[i], 1,
                                       ind.unsqueeze(2).expand_as(Inputs_N[i]))
        return Phis, Inputs_N
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号