def shape_transform(x):
    """Transform the shape of a tensor to fit a conv-layer input.

    Dimensions 1 and 2 of ``x`` are swapped, and a new dimension of size 1
    is then inserted at index 3. For a 3-D input of shape ``(A, B, C)`` the
    result has shape ``(A, C, B, 1)``.

    Args:
        x: A ``torch.Tensor`` with at least three dimensions.
            NOTE(review): presumably laid out as (batch, length, channels)
            so the result is (batch, channels, length, 1) -- confirm
            against the callers.

    Returns:
        The transposed tensor with the extra singleton dimension at index 3.

    Raises:
        IndexError: Raised by torch when ``x`` has fewer than three
            dimensions, since dimension 2 does not exist.
    """
    # Swap dims 1 and 2 first; the singleton dim is added afterwards so it
    # lands at index 3 of the already-transposed tensor.
    swapped = torch.transpose(x, 1, 2)
    return torch.unsqueeze(swapped, 3)