def backward(self, grad_output):
    """Propagate ``grad_output`` back through a cross product.

    Presumably the matching forward computed
    ``torch.cross(input, other, self.dim)`` and saved ``(input, other)``.
    TODO confirm against the forward method, which is not visible here.

    For ``c = a x b`` and upstream gradient ``g``, the scalar ``g . (a x b)``
    equals ``a . (b x g)`` and also ``b . (g x a)``. That gives the two
    gradients returned below.

    Args:
        grad_output: Upstream gradient with respect to the forward output.

    Returns:
        A tuple ``(grad_input, grad_other)``:
        ``grad_input`` is ``cross(other, grad_output)`` and
        ``grad_other`` is ``cross(grad_output, input)``,
        both taken along ``self.dim``.
    """
    lhs, rhs = self.saved_tensors
    axis = self.dim

    # d/da [g . (a x b)] = b x g
    grad_lhs = torch.cross(rhs, grad_output, dim=axis)
    # d/db [g . (a x b)] = g x a
    grad_rhs = torch.cross(grad_output, lhs, dim=axis)

    return grad_lhs, grad_rhs
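# 评论列表 ("Comment list") - navigation label left over from the web page this snippet was copied from.
# 文章目录 ("Table of contents") - navigation label left over from the same web page.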