def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Factors Q once and builds a partial block-LU factorization of

        S = [ A Q^{-1} A^T    A Q^{-1} G^T          ]
            [ G Q^{-1} A^T    G Q^{-1} G^T + D^{-1} ]

    The part that depends on D is left unfilled, so the factorization can
    be completed once D^{-1} is known. See the 'Block LU factorization'
    part of the qpth website for more details.

    Args:
        Q: Batched matrix that is LU-factored with ``btrifact_hack``.
            Presumably (nBatch, nz, nz); TODO confirm against ``get_sizes``.
        G: Batched inequality matrix. It is used with ``torch.bmm``, so it
            must be 3-D (nBatch first).
        A: Batched equality matrix. It is only used when ``neq > 0``.

    Returns:
        A tuple ``(Q_LU, S_LU, R)``:

        * ``Q_LU``: the result of ``btrifact_hack(Q)``. It is splatted into
          ``btrisolve`` to apply Q^{-1}.
        * ``S_LU``: a list ``[S_LU_data, S_LU_pivots]``.
          When ``neq > 0``, ``S_LU_data`` is
          (nBatch, neq + nineq, neq + nineq), and its lower-right
          nineq x nineq block is zeros. Otherwise it is an all-zero
          (nBatch, nineq, nineq) tensor.
          ``S_LU_pivots`` is an int tensor of shape (nBatch, neq + nineq)
          holding 1..neq+nineq. Its first ``neq`` columns are overwritten
          with the pivots of A Q^{-1} A^T.
        * ``R``: G Q^{-1} G^T, minus
          G Q^{-1} A^T (A Q^{-1} A^T)^{-1} (G Q^{-1} A^T)^T when ``neq > 0``.

    Raises:
        RuntimeError: if the LU factorization of Q fails. The underlying
            exception is chained as ``__cause__``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)
    try:
        Q_LU = btrifact_hack(Q)
    except Exception as err:
        # Catch only real errors (not KeyboardInterrupt/SystemExit) and keep
        # the original failure attached for debugging.
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""") from err
    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.
    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrisolve(*Q_LU))
    # R starts as G Q^{-1} G^T. It is corrected below when there are
    # equality constraints.
    R = G_invQ_GT.clone()
    # Pivots 1..(neq + nineq), one row per batch element. These are 1-based
    # and presumably mean "no row swap" in btrifact's pivot convention.
    # type_as(Q) moves the tensor to Q's device/type before the cast to int.
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        # Q^{-1} A^T from the cached LU factors of Q.
        invQ_AT = A.transpose(1, 2).btrisolve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        # LU-factor the (1,1) block A Q^{-1} A^T and unpack it into P, L, U.
        LU_A_invQ_AT = btrifact_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.btriunpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        # Top-left block: the packed LU data of A Q^{-1} A^T.
        S_LU_11 = LU_A_invQ_AT[0]
        # Solve (A Q^{-1} A^T) X = P L. If btriunpack gives
        # A Q^{-1} A^T = P L U, then X = U^{-1} (TODO confirm convention).
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)
                           ).btrisolve(*LU_A_invQ_AT)
        # Bottom-left block: G Q^{-1} A^T U^{-1}.
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        # T = (A Q^{-1} A^T)^{-1} (G Q^{-1} A^T)^T. It is reused for the
        # top-right block and for the correction to R.
        T = G_invQ_AT.transpose(1, 2).btrisolve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        # Bottom-right block depends on D^{-1}, so it stays zero until the
        # factorization is completed later.
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)),
                              1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]
        # Subtract the equality-block contribution. For symmetric Q this
        # leaves the Schur complement of the (1,1) block of S, minus D^{-1}.
        R -= G_invQ_AT.bmm(T)
    else:
        # No equality constraints: the whole of S depends on D^{-1}, so the
        # data starts as zeros and is filled in once D is known.
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)
    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
评论列表
文章目录