def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the KKT equations for the affine step.

    The reduced system is solved with the LU factors ``S_LU``. The remaining
    step components are then recovered by back-substitution:

        dx = Q^{-1} (-rx - G^T dz - A^T dy)     (solved with ``Q_LU``)
        ds = (-rs - dz) / d

    Args:
        Q_LU: Factorization of Q, unpacked as the arguments to ``btrisolve``
            (presumably ``(LU_data, LU_pivots)`` from ``btrifact`` -- confirm
            against the caller).
        d: Per-inequality scaling applied elementwise to ``rs`` and ``ds``.
        G: Batched 3-D inequality matrix (used with ``bmm``/``transpose(1, 2)``).
        A: Batched 3-D equality matrix; only touched when ``neq > 0``.
        S_LU: Factorization of the reduced KKT matrix, unpacked into
            ``btrisolve`` (assumed built by the caller -- TODO confirm).
        rx, rs, rz, ry: 2-D ``(batch, size)`` residuals of the x, s, z and y
            blocks.

    Returns:
        Tuple ``(dx, ds, dz, dy)``. ``dy`` is ``None`` when there are no
        equality constraints (``neq == 0``).
    """
    _nineq, _nz, neq, _nbatch = get_sizes(G, A)
    has_eq = neq > 0

    def row_times(vec, mat):
        # Batched row-vector/matrix product: (B, n) x (B, n, m) -> (B, m).
        return vec.unsqueeze(1).bmm(mat).squeeze(1)

    # Right-hand side of the reduced system. Equality rows (if any) come
    # first, followed by the inequality rows.
    invQ_rx = rx.btrisolve(*Q_LU)
    ineq_rhs = row_times(invQ_rx, G.transpose(1, 2)) + rs / d - rz
    if has_eq:
        eq_rhs = row_times(invQ_rx, A.transpose(1, 2)) - ry
        h = torch.cat((eq_rhs, ineq_rhs), 1)
    else:
        h = ineq_rhs

    w = -(h.btrisolve(*S_LU))
    # The column layout of w mirrors h: first neq columns -> dy, rest -> dz.
    w_z = w[:, neq:]
    w_y = w[:, :neq] if has_eq else None

    # Back-substitute for dx and ds.
    g1 = -rx - row_times(w_z, G)
    if has_eq:
        g1 -= row_times(w_y, A)

    dx = g1.btrisolve(*Q_LU)
    ds = (-rs - w_z) / d
    return dx, ds, w_z, w_y
评论列表
文章目录