def combine_matrices(self, prob_matrix, prob_matrix_scale,
                     perm, last=False):
    """Compose the current-scale assignment matrix with the accumulated one.

    A row that is one-hot at index 0 is prepended to ``prob_matrix_scale``
    along dim 1, so output position 0 always draws from input position 0.
    The augmented matrix is then used twice:

    * its discretized form (``self.discretize``) re-indexes ``perm`` via
      ``torch.gather`` along dim 1;
    * its soft form is batch-multiplied onto ``prob_matrix`` (``torch.bmm``).

    Args:
        prob_matrix: 3-D batch-first tensor; ``prob_matrix.size(1)`` must
            equal ``prob_matrix_scale.size(2)`` for the ``bmm`` to succeed.
        prob_matrix_scale: 3-D tensor ``(bs, rows, length)``.
        perm: index tensor gathered along dim 1 with the indices returned
            by ``self.discretize`` (presumably ``(bs, length)`` — TODO
            confirm against callers).
        last: not used by this method; kept for interface compatibility.

    Returns:
        Tuple ``(prob_matrix, prob_matrix_scale, perm)``:
        the composed matrix ``(bs, rows + 1, prob_matrix.size(2))``,
        the augmented scale matrix ``(bs, rows + 1, length)``, and the
        re-indexed ``perm``.
    """
    # Take the batch size from the input itself rather than self.batch_size,
    # so a short final batch does not make torch.cat fail.
    bs, _, length = prob_matrix_scale.size()
    # Allocate from the input's own tensor type so dtype and device always
    # match prob_matrix_scale (torch.cat needs them to agree), instead of
    # relying on a module-level dtype setting.
    first = Variable(prob_matrix_scale.data.new(bs, 1, length).zero_())
    first[:, 0, 0] = 1.0
    # Prepend as a new *row* (dim 1): output position 0 maps to input 0.
    prob_matrix_scale = torch.cat((first, prob_matrix_scale), 1)
    # Hard assignment. self.discretize is defined elsewhere; presumably an
    # argmax over dim 2 returning integer indices — TODO confirm.
    new_perm = self.discretize(prob_matrix_scale)
    perm = torch.gather(perm, 1, new_perm)
    # Soft composition of the two assignment matrices.
    prob_matrix = torch.bmm(prob_matrix_scale, prob_matrix)
    return prob_matrix, prob_matrix_scale, perm
评论列表
文章目录