def _cholesky_solve_upper(b, U):
    """Solve ``(U^T U) x = b`` for a 1-D right-hand side and return 1-D ``x``.

    ``U`` is treated as an upper-triangular Cholesky factor, which was the
    default (``upper=True``) of the legacy ``torch.potrs``. ``torch.potrs``
    no longer exists in current PyTorch. Its replacement,
    ``torch.cholesky_solve``, defaults to ``upper=False``, so the flag is
    passed explicitly. Very old builds without ``cholesky_solve`` still fall
    back to ``potrs``.
    """
    rhs = b.reshape(-1, 1)  # the solvers expect an (n, k) right-hand side
    if hasattr(torch, 'cholesky_solve'):
        x = torch.cholesky_solve(rhs, U, upper=True)
    else:  # older PyTorch releases
        x = torch.potrs(rhs, U)
    return x.reshape(-1)


def solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry, dbg=False):
    """Solve KKT equations for the affine step.

    The system is solved by block elimination. ``dx`` and ``ds`` are
    eliminated, the reduced system for ``w = [dy; dz]`` is solved with
    ``U_S``, and then the eliminated variables are recovered.

    Assumption (NOTE(review): confirm against the code that builds ``U_S``):
    with ``Q = U_Q^T U_Q`` and ``D = diag(d)``, ``U_S`` is the upper Cholesky
    factor of ``[A; G] Q^{-1} [A; G]^T + blkdiag(0, D^{-1})``. Under that
    assumption the returned step satisfies::

        Q dx + G^T dz + A^T dy = -rx
        D ds + dz              = -rs
        G dx + ds              = -rz
        A dx                   = -ry     (only when there are equalities)

    Args:
        U_Q: Upper-triangular Cholesky factor of Q (used as such in the solves).
        d: Vector dividing ``rs`` and ``g2`` elementwise; presumably the
            diagonal of D (one entry per inequality) -- TODO confirm.
        G: Inequality matrix (used with ``torch.mv``, so 2-D).
        A: Equality matrix; only touched when ``get_sizes`` reports ``neq > 0``.
        U_S: Upper-triangular Cholesky factor of the reduced system (see above).
        rx, rs, rz, ry: Residual vectors for the four block rows.
        dbg: Developer hook. If True, opens an IPython shell and then
            terminates the process with exit code -1.

    Returns:
        Tuple ``(dx, ds, dz, dy)`` of 1-D tensors. ``dy`` is None when there are
        no equality constraints.
    """
    _, _, neq, _ = get_sizes(G, A)

    # Reduced right-hand side h; equality rows come first, inequality rows follow.
    invQ_rx = _cholesky_solve_upper(rx, U_Q)
    h_ineq = torch.mv(G, invQ_rx) + rs / d - rz
    if neq > 0:
        h = torch.cat([torch.mv(A, invQ_rx) - ry, h_ineq], 0)
    else:
        h = h_ineq

    # w = [dy; dz] solves the reduced system S w = -h.
    w = -_cholesky_solve_upper(h, U_S)
    dz = w[neq:]
    dy = w[:neq] if neq > 0 else None

    # Back-substitute for the eliminated variables.
    g1 = -rx - torch.mv(G.t(), dz)
    if dy is not None:
        g1 = g1 - torch.mv(A.t(), dy)
    g2 = -rs - dz

    dx = _cholesky_solve_upper(g1, U_Q)
    ds = g2 / d

    if dbg:
        # NOTE(review): leftover interactive-debugging hook; it kills the
        # process after the shell exits. Kept for backward compatibility.
        import IPython
        import sys
        IPython.embed()
        sys.exit(-1)
    return dx, ds, dz, dy
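# 评论列表 ("comment list") -- stray navigation text from the source web page, not code
# 文章目录 ("table of contents") -- stray navigation text from the source web page, not code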