def forward(self, input, other):
    """Return the cross product of ``input`` and ``other`` along ``self.dim``.

    Both operands are handed to ``self.save_for_backward`` unchanged before
    the result is returned.
    """
    product = torch.cross(input, other, dim=self.dim)
    # Stash the original operands (not the product) for the backward pass.
    self.save_for_backward(input, other)
    return product