def combine_matrices(self, prob_matrix, prob_matrix_scale, perm):
    """Apply ``prob_matrix_scale`` to the running matrix and index tensor.

    Args:
        prob_matrix: Running batched matrix; the right-hand operand of
            the batched product. ``torch.bmm`` requires a 3-D tensor.
        prob_matrix_scale: Batched matrix to fold in; the left-hand
            operand of the product and the input to ``self.discretize``.
            ``torch.bmm`` requires a 3-D tensor.
        perm: Running index tensor that is re-ordered along dim 1.
            Presumably shaped (batch, n) -- TODO confirm against callers.

    Returns:
        A tuple ``(prob_matrix, perm)`` where ``prob_matrix`` is
        ``prob_matrix_scale @ prob_matrix`` computed per batch element
        and ``perm`` is the input ``perm`` gathered along dim 1 with the
        indices returned by ``self.discretize``. The inputs themselves
        are not modified in place by this method.
    """
    # Turn the soft matrix into hard indices. The original note here says
    # "argmax"; ``discretize`` is defined elsewhere, so this assumes it
    # returns an int64 index tensor with as many dims as ``perm`` (what
    # ``torch.gather`` needs) -- TODO confirm.
    new_perm = self.discretize(prob_matrix_scale)

    # out[b][j] = perm[b][new_perm[b][j]]: re-order the running indices.
    perm = torch.gather(perm, 1, new_perm)

    # Left-multiply so the new matrix is applied on top of the running one.
    prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
    return prob_matrix, perm
# [web page residue] Comment list
# [web page residue] Table of contents