def transpose_packed_sequence(ps):
    """
    Goes from a TxB packed sequence to a BxT one, or vice versa.

    Assumes that nothing is a Variable (presumably an autograd Variable --
    TODO confirm against callers).

    :param ps: PackedSequence, or any tuple-like whose first two fields are
        ``data`` and ``batch_sizes``. Any further fields are ignored.
    :return: PackedSequence built from the permuted ``data`` and the sizes
        returned by ``transpose_batch_sizes(batch_sizes)``.
    """
    # Index instead of unpacking into two names: a PackedSequence with extra
    # trailing fields (as in newer PyTorch releases) would make a 2-name
    # unpack raise "too many values to unpack".
    data, batch_sizes = ps[0], ps[1]
    seq_lens = transpose_batch_sizes(batch_sizes)

    # Put things in the permutation matrix one way, take out another way.
    # int() lets batch_sizes[0] be either a plain int or a 0-d tensor.
    num_rows = int(batch_sizes[0])
    perm_mat = torch.LongTensor(num_rows, len(batch_sizes)).zero_()

    # Row i receives the consecutive indices cur .. cur + sl - 1; the rest of
    # the row stays zero and acts as padding for pack_padded_sequence.
    cur = 0
    for i, sl in enumerate(seq_lens):
        for col_ind in range(sl):
            perm_mat[i, col_ind] = cur + col_ind
        cur += sl

    # Packing the index matrix reads the indices back out in the other order,
    # which yields the permutation to apply to the flat data.
    perm = pack_padded_sequence(perm_mat, seq_lens, batch_first=True).data
    return PackedSequence(data[perm], seq_lens)
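# Web-page residue (not part of the code): "Comment list"
# Web-page residue (not part of the code): "Table of contents"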