spbatch.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:qpth 作者: locuslab 项目源码 文件源码
def cat_kkt(Qi, Qv, Qsz, Gi, Gv, Gsz, Ai, Av, Asz, Di, Dv, Dsz, eps):
    """Assemble a batch of sparse KKT matrices from their Q, D, G and A blocks.

    With ``nineq, nz = Gsz`` and ``neq, _ = Asz`` the assembled square matrix
    of side ``k = nz + 2 * nineq + neq`` has the block layout::

        [ Q + eps*I      0          G^T        A^T    ]   nz rows
        [ 0          D + eps*I       I          0     ]   nineq rows
        [ G              I        -eps*I        0     ]   nineq rows
        [ A              0           0       -eps*I   ]   neq rows

    All batch elements share one sparsity pattern; only the values differ.

    Args:
        Qi, Gi, Ai, Di: ``(2, nnz)`` integer index tensors (row 0 = row index,
            row 1 = column index) of the respective blocks, local to each
            block.  ``Di`` is shifted by ``nz`` in both coordinates.
        Qv, Gv, Av, Dv: ``(nBatch, nnz)`` value tensors matching the indices.
        Qsz, Dsz: Unused; kept for interface compatibility.
        Gsz: Size of G, unpacked as ``(nineq, nz)``.
        Asz: Size of A, unpacked as ``(neq, nz)``.
        eps: Regularisation added to the first ``nz + nineq`` diagonal
            entries and subtracted from the remaining ones.

    Returns:
        A tuple ``(Ks, [Ki, Kv, Ksz], Didx)`` where ``Ks`` is a list of
        ``nBatch`` coalesced sparse tensors, ``Ki`` is the shared ``(2, nnz)``
        index tensor of the coalesced matrices, ``Kv`` the stacked
        ``(nBatch, nnz)`` values, ``Ksz`` the ``torch.Size([k, k])`` and
        ``Didx`` a 1-D tensor of positions (into the ``nnz`` axis) of the
        diagonal entries belonging to the D block.
    """
    nBatch = Qv.size(0)

    nineq, nz = Gsz
    neq, _ = Asz

    # Row/column offsets of the third and fourth block rows, and total size.
    g_off = nz + nineq
    a_off = nz + 2 * nineq
    k = a_off + neq

    def diag_indices(start, stop):
        # Index pairs (i, i) for start <= i < stop, matching Qi's dtype/device.
        r = torch.arange(start, stop, dtype=Qi.dtype, device=Qi.device)
        return torch.stack([r, r])

    def const_values(ncols, value):
        # (nBatch, ncols) tensor of `value`, matching Qv's dtype/device.
        return Qv.new_empty((nBatch, ncols)).fill_(value)

    Di = Di + nz

    # G and A go below the diagonal; their transposes (swapped index rows)
    # go above it.  Building new tensors leaves the caller's Gi/Ai untouched.
    Gi_L = torch.stack([Gi[0] + g_off, Gi[1]])
    Gi_U = torch.stack([Gi[1], Gi[0] + g_off])
    Ai_L = torch.stack([Ai[0] + a_off, Ai[1]])
    Ai_U = torch.stack([Ai[1], Ai[0] + a_off])

    # Identity blocks coupling the second and third block rows.
    eye = torch.arange(nineq, dtype=Qi.dtype, device=Qi.device)
    Ii_L = torch.stack([eye + g_off, eye + nz])
    Ii_U = torch.stack([eye + nz, eye + g_off])
    Iv = const_values(nineq, 1.0)

    # +eps on the leading diagonal part, -eps on the trailing part.
    Ii_11 = diag_indices(0, g_off)
    Iv_11 = const_values(g_off, eps)
    Ii_22 = diag_indices(g_off, k)
    Iv_22 = const_values(nineq + neq, -eps)

    Ki = torch.cat((Qi, Di, Gi_L, Gi_U, Ai_L, Ai_U,
                    Ii_L, Ii_U, Ii_11, Ii_22), 1)
    Kv = torch.cat((Qv, Dv, Gv, Gv, Av, Av,
                    Iv, Iv, Iv_11, Iv_22), 1)
    Ksz = torch.Size([k, k])

    # Sort entries by (row, column) before building the sparse tensors.
    order = np.lexsort((Ki[1].cpu().numpy(), Ki[0].cpu().numpy()))
    order = torch.from_numpy(order).to(device=Ki.device, dtype=torch.long)
    Ki = Ki[:, order].contiguous()
    Kv = Kv[:, order].contiguous()

    # Coalescing sums duplicate coordinates (e.g. a diagonal entry of Q or D
    # plus its eps term).  Device and dtype follow the inputs instead of
    # being pinned to CUDA float64.
    Ks = [torch.sparse_coo_tensor(Ki, Kv[i], Ksz).coalesce()
          for i in range(nBatch)]
    Ki = Ks[0]._indices()
    Kv = torch.stack([K._values() for K in Ks])

    # Positions of the diagonal entries that fall inside the D block.
    on_diag = Ki[0] == Ki[1]
    in_d_block = (Ki[0] >= nz) & (Ki[0] < g_off)
    # squeeze(1) keeps Didx 1-D even when only a single entry matches.
    Didx = torch.nonzero(on_diag & in_d_block).squeeze(1)

    return Ks, [Ki, Kv, Ksz], Didx
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号