def factor_kkt(S_LU, R, d):
""" Factor the U22 block that we can only do after we know D. """
nBatch, nineq = d.size()
neq = S_LU[1].size(1) - nineq
    # TODO: There's probably a better way to add a batched diagonal; see the
    # torch.diag_embed sketch after this function.
    global factor_kkt_eye
    # Rebuild the cached identity mask only when the batched shape changes.
    # The mask is (nBatch, nineq, nineq), so it must be compared against
    # R.size(); comparing against d.size(), which is (nBatch, nineq), never
    # matches and would rebuild the mask on every call.
    if factor_kkt_eye is None or factor_kkt_eye.size() != R.size():
        factor_kkt_eye = torch.eye(nineq).repeat(
            nBatch, 1, 1).type_as(R).byte()
    T = R.clone()
    # Byte-mask indexing yields the diagonal entries as a flat
    # (nBatch * nineq,) tensor, so flatten 1/d to match before adding.
    T[factor_kkt_eye] += (1. / d).view(-1)

    T_LU = btrifact_hack(T)
    global shown_btrifact_warning
    # This branch runs whenever T was factored *with* pivoting: always on
    # CPU, and on CUDA only if btrifact_hack had to fall back to a pivoting
    # btrifact (signalled by shown_btrifact_warning). The pivots already
    # stored in S_LU must then be reconciled with the new ones from T_LU.
    if shown_btrifact_warning or not T.is_cuda:
        # TODO: Don't use pivoting in most cases, because
        # torch.btriunpack is inefficient here:
        # The packed pivots in S_LU index into the full KKT system; shift
        # them by neq to index into the trailing nineq x nineq block.
        oldPivotsPacked = S_LU[1][:, -nineq:] - neq
        oldPivots, _, _ = torch.btriunpack(
            T_LU[0], oldPivotsPacked, unpack_data=False)
        newPivotsPacked = T_LU[1]
        newPivots, _, _ = torch.btriunpack(
            T_LU[0], newPivotsPacked, unpack_data=False)
        # Re-pivot the S_LU_21 block: undo the old row permutation, then
        # apply the new one (P_new^T P_old S_LU_21).
        if neq > 0:
            S_LU_21 = S_LU[0][:, -nineq:, :neq]
            S_LU[0][:, -nineq:, :neq] = newPivots.transpose(1, 2).bmm(
                oldPivots.bmm(S_LU_21))
        # Add the new S_LU_22 block pivots.
        S_LU[1][:, -nineq:] = newPivotsPacked + neq

    # Add the new S_LU_22 block.
    S_LU[0][:, -nineq:, -nineq:] = T_LU[0]
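

# Regarding the TODO in factor_kkt above: on PyTorch versions that provide
# torch.diag_embed (an assumption about the installed version; it is not
# part of the old btrifact-era API used above), the batched diagonal can be
# added without the cached byte mask. A minimal sketch; the helper name is
# hypothetical and not part of the original module:
def add_batched_diag(R, d):
    """Return R + diag(1/d) per batch element, for R of shape
    (nBatch, nineq, nineq) and d of shape (nBatch, nineq) with
    nonzero entries."""
    # diag_embed expands 1/d into a batch of (nineq, nineq) diagonal
    # matrices, so the diagonal update applies across the whole batch
    # without any mask or global cache.
    return R + torch.diag_embed(1. / d)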
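

# For context, factor_kkt relies on module-level state that is defined
# elsewhere in the file this section was taken from. The definitions below
# are a hedged reconstruction of what the function assumes (the names match
# the usage above, but the exact original bodies may differ):

import torch

factor_kkt_eye = None            # cached batched identity mask for the diagonal update
shown_btrifact_warning = False   # set once btrifact falls back to pivoting on CUDA


def btrifact_hack(x):
    # Batched LU factorization via the old torch.btrifact API. Pivoting is
    # requested only on CPU; if this PyTorch build's btrifact does not accept
    # the pivot keyword, fall back to the default (pivoting) and warn once.
    global shown_btrifact_warning
    try:
        return x.btrifact(pivot=not x.is_cuda)
    except TypeError:
        if not shown_btrifact_warning:
            print('warning: btrifact does not support the pivot argument '
                  'in this PyTorch version; pivoting will always be used.')
            shown_btrifact_warning = True
        return x.btrifact()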