def __call__(self, *inputs):
    """Swap dimensions ``self.dim1`` and ``self.dim2`` of every input.

    Args:
        *inputs: tensors to transpose. Each one is passed to ``th.transpose``
            together with ``self.dim1`` and ``self.dim2``.

    Returns:
        The single transposed tensor when exactly one input is given.
        Otherwise, a list of transposed tensors in input order. Calling with
        no inputs returns an empty list.
    """
    outputs = [th.transpose(_input, self.dim1, self.dim2) for _input in inputs]
    # Unwrap only the single-input case. The previous check (`idx > 1`) also
    # unwrapped the two-input case, which silently dropped the second tensor.
    # It also raised UnboundLocalError when called with no inputs.
    return outputs[0] if len(outputs) == 1 else outputs