spbatch.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:qpth 作者: locuslab 项目源码 文件源码
def solve_kkt(Ks, K, Ktildes, Ktilde,
              rx, rs, rz, ry, niter=1):
    nBatch = len(Ks)
    nz = rx.size(1)
    nineq = rz.size(1)
    neq = ry.size(1)

    r = -torch.cat((rx, rs, rz, ry), 1)

    l = torch.spbqrfactsolve(*([r] + Ktilde))
    res = torch.stack([r[i] - torch.mm(Ks[i], l[i].unsqueeze(1))
                       for i in range(nBatch)])
    for k in range(niter):
        d = torch.spbqrfactsolve(*([res] + Ktilde))
        l = l + d
        res = torch.stack([r[i] - torch.mm(Ks[i], l[i].unsqueeze(1))
                           for i in range(nBatch)])

    solx = l[:, :nz]
    sols = l[:, nz:nz + nineq]
    solz = l[:, nz + nineq:nz + 2 * nineq]
    soly = l[:, nz + 2 * nineq:nz + 2 * nineq + neq]

    return solx, sols, solz, soly
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号