def factor_kkt(U_S, R, d):
    """Factor the U22 block that we can only do after we know D.

    Writes, in place, the upper-triangular Cholesky factor U of
    ``R + diag(1 / d)`` into the bottom-right ``nineq x nineq`` block of
    ``U_S``, where ``nineq = R.size(0)``, so that ``U.T @ U == R + diag(1 / d)``.

    Args:
        U_S: 2-D tensor that receives the factor; only its trailing
            ``nineq`` rows/columns are overwritten.
        R: square ``(nineq, nineq)`` tensor. It must be such that
            ``R + diag(1 / d)`` is symmetric positive definite, otherwise
            the Cholesky factorization raises.
        d: 1-D tensor of length ``nineq`` (entries must be non-zero).

    Returns:
        None. ``U_S`` is modified in place.
    """
    nineq = R.size(0)
    if nineq == 0:
        # U_S[-0:, -0:] would address the WHOLE matrix, not an empty block.
        return
    row0 = U_S.size(0) - nineq
    col0 = U_S.size(1) - nineq
    # torch.potrf (removed from PyTorch) returned the upper factor by default;
    # torch.linalg.cholesky returns the lower factor L, so transpose to get U = L^T.
    # torch.diag(1 / d) already lives on d's device/dtype; no CPU round-trip needed.
    lower = torch.linalg.cholesky(R + torch.diag(1 / d))
    U_S[row0:, col0:] = lower.transpose(-2, -1)