def backward(self, grad_output):
    """Return the gradients for the two tensors saved by the forward pass.

    Reads the pair ``(input, other)`` from ``self.saved_tensors`` and the
    axis from ``self.dim``, then applies ``torch.cross`` along that axis.

    NOTE(review): the matching ``forward`` is not visible here; this
    presumably backpropagates through ``torch.cross(input, other, self.dim)``
    -- confirm against the forward definition.

    Args:
        grad_output: Gradient flowing in from the next stage; it is passed
            straight to ``torch.cross`` together with each saved tensor.

    Returns:
        A ``(grad_input, grad_other)`` tuple, in the same order as the
        saved tensors.
    """
    # Local is named ``input_`` so the ``input`` builtin is not shadowed.
    input_, other = self.saved_tensors
    dim = self.dim

    # Scalar triple product identity: g . (a x b) == a . (b x g) == b . (g x a),
    # so for out = a x b the gradient w.r.t. a is (b x g) and w.r.t. b is (g x a).
    # The cross product is anti-commutative, so operand order matters here.
    grad_input = torch.cross(other, grad_output, dim)
    grad_other = torch.cross(grad_output, input_, dim)
    return grad_input, grad_other