DCN.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:DCN 作者: alexnowakvila 项目源码 文件源码
def eliminate_rows(self, prob_sc, ind, phis, threshold=0.85):
        """Move confidently-resolved rows to the back and drop them from phis.

        For every batch element ``b``, row ``i`` is *eliminated* when
        ``prob_sc[b, i, 0] > threshold``. Rows that survive are moved to the
        front in their original relative order; eliminated rows are moved to
        the back. The same permutation is applied to ``prob_sc``, composed
        into ``ind``, and used to permute (and mask) both axes of ``phis``.

        Args:
            prob_sc: tensor of shape ``(B, L, K)``; only channel 0 is compared
                against ``threshold``.
            ind: index tensor of shape ``(B, L)`` holding the current
                permutation. It is composed with the new one as
                ``ind_new[b, j] = ind[b, ind_sc[b, j]]``.
            phis: tensor of shape ``(B, L, L)``. It is gathered with the
                *composed* ``ind``. Presumably it is therefore in the original
                (un-permuted) ordering that ``ind`` indexes into — confirm
                against the caller.
            threshold: elimination threshold on ``prob_sc[:, :, 0]``. The
                default, 0.85, is the value this method previously hard-coded.

        Returns:
            ``(prob_sc, ind, phis_out, active)`` where:

            * ``prob_sc`` is permuted along dim 1, and every eliminated row is
              replaced by ``[1, 0, ..., 0]``;
            * ``ind`` is the composed permutation;
            * ``phis_out[b, j, k] = phis[b, ind[b, j], ind[b, k]]
              * active[b, j] * active[b, k]``;
            * ``active`` is a ``(B, L)`` tensor of type ``dtype``: 1 for
              surviving rows (now at the front), 0 for eliminated rows.

        Note: ``self`` is not used; ``dtype`` and ``Variable`` are
        module-level names from the enclosing file.
        """
        length = prob_sc.size()[1]
        # 1 where the row is eliminated, 0 where it stays active.
        mask = (prob_sc[:, :, 0] > threshold).type(dtype)
        # Index ramp 0..length-1 per batch element. torch.arange(0, length)
        # yields the same values as the deprecated torch.range(0, length - 1).
        rang = (Variable(torch.arange(0, length).unsqueeze(0)
                .expand_as(mask)).
                type(dtype))
        # Sort key: active rows keep their (distinct) index, eliminated rows
        # get `length`. An ascending sort therefore keeps the active rows in
        # order at the front and pushes eliminated rows to the back.
        ind_sc = torch.sort(rang * (1 - mask) + length * mask, 1)[1]
        # permute prob_sc: zero the eliminated rows, then set their channel 0
        # to 1 (mm equals the mask on channel 0 and is 0 elsewhere).
        m = mask.unsqueeze(2).expand_as(prob_sc)
        mm = m.clone()
        mm[:, :, 1:] = 0
        prob_sc = (torch.gather(prob_sc * (1 - m) + mm, 1,
                   ind_sc.unsqueeze(2).expand_as(prob_sc)))
        # compose permutations
        ind = torch.gather(ind, 1, ind_sc)
        active = torch.gather(1 - mask, 1, ind_sc)
        # permute phis: first along rows (dim 1), then along columns (dim 2),
        # zeroing any row/column that belongs to an eliminated element.
        active1 = active.unsqueeze(2).expand_as(phis)
        ind1 = ind.unsqueeze(2).expand_as(phis)
        active2 = active.unsqueeze(1).expand_as(phis)
        ind2 = ind.unsqueeze(1).expand_as(phis)
        phis_out = torch.gather(phis, 1, ind1) * active1
        phis_out = torch.gather(phis_out, 2, ind2) * active2
        return prob_sc, ind, phis_out, active
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号