single.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:qpth 作者: locuslab 项目源码 文件源码
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    nineq, nz, neq, _ = get_sizes(G, A)

    if neq > 0:
        H_ = torch.cat([torch.cat([Q, torch.zeros(nz, nineq).type_as(Q)], 1),
                        torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([torch.cat([G, torch.eye(nineq).type_as(Q)], 1),
                        torch.cat([A, torch.zeros(neq, nineq).type_as(Q)], 1)], 0)
        g_ = torch.cat([rx, rs], 0)
        h_ = torch.cat([rz, ry], 0)
    else:
        H_ = torch.cat([torch.cat([Q, torch.zeros(nz, nineq).type_as(Q)], 1),
                        torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        g_ = torch.cat([rx, rs], 0)
        h_ = rz

    U_H_ = torch.potrf(H_)

    invH_A_ = torch.potrs(A_.t(), U_H_)
    invH_g_ = torch.potrs(g_.view(-1, 1), U_H_).view(-1)

    S_ = torch.mm(A_, invH_A_)
    U_S_ = torch.potrf(S_)
    t_ = torch.mv(A_, invH_g_).view(-1, 1) - h_
    w_ = -torch.potrs(t_, U_S_).view(-1)
    v_ = torch.potrs(-g_.view(-1, 1) - torch.mv(A_.t(), w_), U_H_).view(-1)

    return v_[:nz], v_[nz:], w_[:nineq], w_[nineq:] if neq > 0 else None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号