def backward(ctx, grad_output, grad_LU=None):
    """Backward pass for the linear solve ``a @ X = b``.

    NOTE(review): assumes the forward pass computed ``X = a^{-1} b`` from
    inputs ``(b, a)`` and saved ``(X, a)`` on ``ctx``; confirm against
    ``forward``. With ``G = grad_output`` (dL/dX) the gradients are::

        dL/db = a^{-T} G
        dL/da = -(dL/db) X^T

    Args:
        ctx: autograd context holding the saved ``(X, a)`` tensors.
        grad_output: gradient of the loss w.r.t. the solution ``X``.
        grad_LU: gradient w.r.t. the forward's second output (presumably
            the LU factorisation). It is ignored, so that output is
            effectively treated as non-differentiable.

    Returns:
        ``(grad_b, grad_a)`` in forward-input order. An entry is ``None``
        when ``ctx.needs_input_grad`` reports that the corresponding input
        does not require a gradient.
    """
    # saved_tensors replaces the deprecated saved_variables alias.
    X, a = ctx.saved_tensors
    needs_b, needs_a = ctx.needs_input_grad[:2]

    # grad_b is always computed: grad_a is derived from it.
    # torch.gesv has been superseded (torch.solve, then torch.linalg.solve)
    # and is absent from current PyTorch, so use whichever solver exists.
    a_t = a.t()
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve"):
        # linalg.solve(A, B) solves A @ Y = B (note the argument order).
        grad_b = linalg.solve(a_t, grad_output)
    elif hasattr(torch, "solve"):
        # torch.solve(B, A) returns (solution, LU).
        grad_b, _ = torch.solve(grad_output, a_t)
    else:
        # torch.gesv(B, A) returns (solution, LU).
        grad_b, _ = torch.gesv(grad_output, a_t)

    # Skip the extra matmul when `a` does not need a gradient.
    grad_a = -torch.mm(grad_b, X.t()) if needs_a else None
    return (grad_b if needs_b else None), grad_a