DCN.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:DCN 作者: alexnowakvila 项目源码 文件源码
def combine_matrices(self, prob_matrix, prob_matrix_scale,
                         perm, last=False):
        # prob_matrix shape is bs x length x length + 1. Add extra column.
        length = prob_matrix_scale.size()[2]
        first = Variable(torch.zeros([self.batch_size, 1, length])).type(dtype)
        first[:, 0, 0] = 1.0
        prob_matrix_scale = torch.cat((first, prob_matrix_scale), 1)
        # argmax
        new_perm = self.discretize(prob_matrix_scale)
        perm = torch.gather(perm, 1, new_perm)
        # combine
        prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
        return prob_matrix, prob_matrix_scale, perm
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号