def forward(ctx, input, other, dim=-1):
    """Return the cross product of ``input`` and ``other`` along ``dim``.

    Written against the autograd ``Function`` protocol: ``ctx`` is the
    context object and operands are stashed with ``save_for_backward``.
    The enclosing class and any ``@staticmethod`` decorator are not visible
    in this view.

    Args:
        ctx: Autograd context. It receives ``dim`` as ``ctx.dim`` and both
            operands via ``save_for_backward``.
        input: First operand tensor.
        other: Second operand tensor.
        dim: Dimension along which the cross product is taken. Defaults to
            -1 (the last dimension). ``torch.cross`` requires size 3 there.

    Returns:
        The tensor computed by ``torch.cross(input, other, dim)``.
    """
    # Each operand appears in the other's gradient of a cross product,
    # so both are kept for the backward pass (not shown here).
    ctx.save_for_backward(input, other)
    # Record the reduction axis so backward can presumably reuse it.
    ctx.dim = dim
    return torch.cross(input, other, dim=ctx.dim)