def cat_kkt(Qi, Qv, Qsz, Gi, Gv, Gsz, Ai, Av, Asz, Di, Dv, Dsz, eps):
    """Assemble a batch of eps-regularized sparse KKT matrices.

    Each ``(Xi, Xv)`` pair is a COO description whose sparsity pattern is
    shared by the whole batch: ``Xi`` is a ``(2, nnz_X)`` integer tensor with
    row indices in ``Xi[0]`` and column indices in ``Xi[1]``, and ``Xv`` is a
    ``(nBatch, nnz_X)`` tensor holding one row of values per batch element
    (``nBatch = Qv.size(0)``).

    With ``nineq, nz = Gsz``, ``neq = Asz[0]`` and ``k = nz + 2*nineq + neq``,
    every batch element produces the ``k x k`` matrix::

        [ Q + eps*I      0        G^T      A^T   ]  rows [0, nz)
        [    0       D + eps*I     I        0    ]  rows [nz, nz+nineq)
        [    G           I     -eps*I       0    ]  rows [nz+nineq, nz+2*nineq)
        [    A           0         0    -eps*I   ]  rows [nz+2*nineq, k)

    Entries landing on the same coordinate (e.g. a diagonal entry of Q and
    its eps shift) are summed by ``coalesce()``.

    Args:
        Qi, Qv: COO indices / batched values of Q. ``Qsz`` is not used.
        Gi, Gv, Gsz: COO indices / batched values / ``(nineq, nz)`` size of G.
        Ai, Av, Asz: COO indices / batched values / ``(neq, nz)`` size of A.
        Di, Dv: COO indices / batched values of the ``nineq x nineq`` block D
            (presumably the slack scaling of an interior-point step --
            confirm with callers). ``Dsz`` is not used.
        eps: added to the diagonal of the first two block rows and
            subtracted from the diagonal of the last two.

    Returns:
        ``(Ks, [Ki, Kv, Ksz], Didx)`` where ``Ks`` is a list of ``nBatch``
        coalesced double-precision sparse matrices (CUDA if the inputs are on
        the GPU, CPU otherwise); ``Ki`` is their shared ``(2, nnz)`` index
        tensor; ``Kv`` is the ``(nBatch, nnz)`` stack of their values; ``Ksz``
        is ``torch.Size([k, k])``; and ``Didx`` holds the positions along
        ``Ki``'s second dim of the diagonal entries of the D block. ``Didx``
        is ``.squeeze()``-d, so it is 0-d when there is exactly one such entry.

    Note:
        Values must be float64: the double-precision sparse constructors are
        used, as before.
    """
    nBatch = Qv.size(0)

    nineq, nz = Gsz
    neq, _ = Asz

    # D sits on the second diagonal block: shift both its rows and columns.
    Di = Di + nz

    # G below the diagonal (shift rows) and G^T above it (swap rows/cols,
    # then shift the new column index); both halves reuse the values Gv.
    Gi_L = Gi.clone()
    Gi_L[0, :] += nz + nineq
    Gi_U = torch.stack([Gi[1, :], Gi[0, :]])
    Gi_U[1, :] += nz + nineq

    # Same construction for A / A^T in the last block row / column.
    Ai_L = Ai.clone()
    Ai_L[0, :] += nz + 2 * nineq
    Ai_U = torch.stack([Ai[1, :], Ai[0, :]])
    Ai_U[1, :] += nz + 2 * nineq

    # Identity couplings between block 2 and block 3, in both directions.
    # ``Tensor.new`` inherits dtype *and* device from its source tensor;
    # ``type(Qi)(...)`` would give a CPU float32 tensor on PyTorch >= 0.4,
    # where ``type(x)`` is just ``torch.Tensor``.
    ineq_idx = list(range(nineq))
    Ii_L = Qi.new([ineq_idx, ineq_idx])
    Ii_U = Ii_L.clone()
    Ii_L[0, :] += nz + nineq
    Ii_L[1, :] += nz
    Ii_U[0, :] += nz
    Ii_U[1, :] += nz + nineq
    Iv_L = Qv.new(nBatch, nineq).fill_(1.0)
    Iv_U = Iv_L.clone()

    # +eps on the diagonal of blocks 1-2, -eps on the diagonal of blocks 3-4.
    # With eps == 0 these still add explicit (zero) diagonal entries.
    diag_11 = list(range(nz + nineq))
    Ii_11 = Qi.new([diag_11, diag_11])
    Iv_11 = Qv.new(nBatch, nz + nineq).fill_(eps)
    diag_22 = list(range(nz + nineq, nz + 2 * nineq + neq))
    Ii_22 = Qi.new([diag_22, diag_22])
    Iv_22 = Qv.new(nBatch, nineq + neq).fill_(-eps)

    Ki = torch.cat((Qi, Di, Gi_L, Gi_U, Ai_L, Ai_U,
                    Ii_L, Ii_U, Ii_11, Ii_22), 1)
    Kv = torch.cat((Qv, Dv, Gv, Gv, Av, Av,
                    Iv_L, Iv_U, Iv_11, Iv_22), 1)
    k = nz + 2 * nineq + neq
    Ksz = torch.Size([k, k])

    # Sort entries row-major: np.lexsort uses its *last* key as the primary.
    order = np.lexsort((Ki[1].cpu().numpy(), Ki[0].cpu().numpy()))
    perm = torch.LongTensor(order)
    if Ki.is_cuda:
        # Keep the permutation on the same GPU as the indices it reorders.
        perm = perm.cuda(Ki.get_device())
    Ki = Ki.t()[perm].t().contiguous()
    Kv = Kv.t()[perm].t().contiguous()

    # Choose the double-precision sparse constructor matching the inputs'
    # device instead of assuming a GPU.
    sparse_ctor = (torch.cuda.sparse.DoubleTensor if Kv.is_cuda
                   else torch.sparse.DoubleTensor)
    Ks = [sparse_ctor(Ki, Kv[i], Ksz).coalesce() for i in range(nBatch)]

    # Every Ks[i] was built from the same Ki, so batch 0's coalesced index
    # set is reused for the whole batch.
    Ki = Ks[0]._indices()
    Kv = torch.stack([K._values() for K in Ks])

    # Locate the diagonal entries of the D block among the coalesced nonzeros.
    on_diag = Ki[0] == Ki[1]
    in_d_block = (Ki[0] >= nz) & (Ki[0] < nz + nineq)
    Didx = torch.nonzero(on_diag & in_d_block).squeeze()

    return Ks, [Ki, Kv, Ksz], Didx
评论列表
文章目录