def outputs(self, input, prob_matrix, perm):
    """Return the hard (exact) and soft (relaxed) reorderings of ``input``.

    Args:
        input: 3-D tensor (``torch.bmm`` requires 3-D) whose rows along
            dim 1 are reordered. Presumably ``(batch, n, d)`` — TODO confirm
            against callers.
        prob_matrix: tensor batch-multiplied with ``input`` via
            ``torch.bmm``, so it must be ``(batch, k, n)`` when ``input`` is
            ``(batch, n, d)``.
        perm: integer tensor of shape ``(batch, k)`` holding row indices into
            dim 1 of ``input``. ``k`` need not equal ``n``; any integer dtype
            is accepted (it is converted to int64 for ``torch.gather``).

    Returns:
        Tuple ``(hard_output, soft_output)``, both of shape ``(batch, k, d)``,
        where ``hard_output[b, i] == input[b, perm[b, i]]`` and
        ``soft_output == prob_matrix @ input`` (batched).
    """
    # Broadcast each row index across the last dim so gather copies whole
    # rows. Using expand(-1, -1, d) instead of expand_as(input) lets perm pick
    # a number of rows different from input.size(1), matching what bmm already
    # allows for prob_matrix. gather requires int64 indices.
    index = perm.long().unsqueeze(2).expand(-1, -1, input.size(2))
    hard_output = torch.gather(input, 1, index)
    # soft argmax: each output row is the combination of input rows weighted
    # by the corresponding row of prob_matrix.
    soft_output = torch.bmm(prob_matrix, input)
    return hard_output, soft_output