def argmax(self, xs):
    """Decode ``xs`` by delegating to the parent class's ``argmax``.

    The batch is reordered with ``permutate_list`` using the indices
    produced by ``argsort_list_descent``, passed through
    ``F.transpose_sequence``, and handed to the parent implementation.
    The path returned by the parent is passed through
    ``F.transpose_sequence`` once more before being returned.

    Args:
        xs: Batch of sequences to decode (presumably a list of
            per-sequence score arrays of varying length -- TODO confirm
            against callers).

    Returns:
        A ``(score, path)`` tuple, where ``score`` is exactly what the
        parent ``argmax`` returned and ``path`` is the parent's path after
        ``F.transpose_sequence``.
    """
    # Presumably orders the sequences longest-first, which is what a
    # sequence transpose typically requires -- verify against the helpers.
    order = argsort_list_descent(xs)
    reordered = permutate_list(xs, order, inv=False)
    transposed = F.transpose_sequence(reordered)

    score, raw_path = super(CRF, self).argmax(transposed)

    # NOTE(review): no inverse permutation is applied here, so the results
    # appear to stay in the reordered order rather than the caller's
    # original order -- confirm that callers account for this.
    decoded_path = F.transpose_sequence(raw_path)
    return score, decoded_path