single.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:qpth 作者: locuslab 项目源码 文件源码
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, _ = get_sizes(G, A)

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T           ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    U_Q = torch.potrf(Q)
    # partial cholesky of S matrix
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD but
        # numpy does? I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except:
            U11 = torch.Tensor(np.linalg.cholesky(
                A_invQ_AT.cpu().numpy())).type_as(A_invQ_AT)

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号