def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.
Returns a tuple indexed by:
0: The pivots.
1: The L tensor.
2: The U tensor.
Arguments:
LU_data (Tensor): the packed LU factorization data
LU_pivots (Tensor): the packed LU factorization pivots
unpack_data (bool): flag indicating if the data should be unpacked
unpack_pivots (bool): tlag indicating if the pivots should be unpacked
"""
nBatch, sz, _ = LU_data.size()
if unpack_data:
I_U = torch.triu(torch.ones(sz, sz)).type_as(LU_data).byte().unsqueeze(0).expand(nBatch, sz, sz)
I_L = 1 - I_U
L = LU_data.new(LU_data.size()).zero_()
U = LU_data.new(LU_data.size()).zero_()
I_diag = torch.eye(sz).type_as(LU_data).byte().unsqueeze(0).expand(nBatch, sz, sz)
L[I_diag] = 1.0
L[I_L] = LU_data[I_L]
U[I_U] = LU_data[I_U]
else:
L = U = None
if unpack_pivots:
P = torch.eye(sz).type_as(LU_data).unsqueeze(0).repeat(nBatch, 1, 1)
for i in range(nBatch):
for j in range(sz):
k = LU_pivots[i, j] - 1
t = P[i, :, j].clone()
P[i, :, j] = P[i, :, k]
P[i, :, k] = t
else:
P = None
return P, L, U
评论列表
文章目录