def forward(ctx, b, a):
    """Run ``torch.gesv(b, a)`` and record what the backward pass needs.

    Args:
        ctx: Context object exposing ``save_for_backward`` and
            ``mark_non_differentiable`` (presumably the autograd function
            context; confirm against the enclosing class).
        b: First argument, forwarded unchanged to ``torch.gesv``.
        a: Second argument, forwarded unchanged to ``torch.gesv``; it is
            also saved for the backward pass.

    Returns:
        A 2-tuple holding the two outputs of ``torch.gesv`` in their
        original order. The second output is flagged as non-differentiable.
    """
    # TODO see if one can backprop through LU
    solution, lu_factors = torch.gesv(b, a)

    # Keep the first output and ``a`` available to the backward pass.
    ctx.save_for_backward(solution, a)
    # No gradient is propagated through the second output (see TODO above).
    ctx.mark_non_differentiable(lu_factors)

    return solution, lu_factors