def solve_kkt(Ks, K, Ktildes, Ktilde,
              rx, rs, rz, ry, niter=1, *, factsolve=None):
    """Solve a batch of KKT systems, refining the solution iteratively.

    For each batch element ``i`` this solves ``Ks[i] @ sol[i] = r[i]`` with
    ``r = -cat((rx, rs, rz, ry), 1)``. An initial solution comes from
    ``factsolve(r, *Ktilde)``. It is then improved ``niter`` times by solving
    for the residual ``r[i] - Ks[i] @ sol[i]`` with the same factor-solve and
    adding that correction (iterative refinement).

    Args:
        Ks: Sequence of per-batch matrices; ``Ks[i]`` is multiplied with
            ``torch.mm`` against a column vector of length
            ``nz + 2 * nineq + neq``.
        K: Not used by this function; kept for interface compatibility.
        Ktildes: Not used by this function; kept for interface compatibility.
        Ktilde: Sequence of extra positional arguments passed to
            ``factsolve`` after the right-hand side. NOTE(review): presumably
            a factorization of a (regularized) KKT matrix -- confirm with
            the caller.
        rx: Tensor of shape (nBatch, nz).
        rs: Tensor of shape (nBatch, nineq).
        rz: Tensor of shape (nBatch, nineq).
        ry: Tensor of shape (nBatch, neq).
        niter: Number of refinement steps; 0 returns the initial solve.
        factsolve: Callable ``factsolve(rhs, *Ktilde)`` returning a tensor
            shaped like ``rhs`` (nBatch, n). Defaults to
            ``torch.spbqrfactsolve``, looked up at call time.

    Returns:
        Tuple ``(solx, sols, solz, soly)`` of column slices of the batched
        solution with widths ``nz``, ``nineq``, ``nineq`` and ``neq``.
    """
    if factsolve is None:
        # Resolved lazily. NOTE(review): spbqrfactsolve does not appear in
        # upstream PyTorch releases, so it is only required when no
        # alternative solver is supplied.
        factsolve = torch.spbqrfactsolve

    nBatch = len(Ks)
    nz = rx.size(1)
    nineq = rz.size(1)
    neq = ry.size(1)

    r = -torch.cat((rx, rs, rz, ry), 1)

    def residual(sol):
        # torch.mm yields an (n, 1) column; squeeze it back to (n,) so that
        # subtracting it from the 1-D r[i] does not broadcast to (n, n).
        return torch.stack([
            r[i] - torch.mm(Ks[i], sol[i].unsqueeze(1)).squeeze(1)
            for i in range(nBatch)
        ])

    sol = factsolve(r, *Ktilde)
    for _ in range(niter):
        # The residual is formed only when a correction step consumes it.
        sol = sol + factsolve(residual(sol), *Ktilde)

    solx = sol[:, :nz]
    sols = sol[:, nz:nz + nineq]
    solz = sol[:, nz + nineq:nz + 2 * nineq]
    soly = sol[:, nz + 2 * nineq:nz + 2 * nineq + neq]

    return solx, sols, solz, soly
评论列表
文章目录