def reverse_padded_sequence(inputs, lengths, batch_first=False):
    """Reverse each sequence in a padded batch according to its length.

    Inputs should have size ``T x B x *`` if ``batch_first`` is False, or
    ``B x T x *`` if True. T is the length of the longest sequence (or larger),
    B is the batch size, and * is any number of dimensions (including 0).

    For sequence ``i`` only the first ``lengths[i]`` time steps are reversed;
    positions past that length (padding) stay where they are. A length of 0
    or less leaves the sequence unchanged, and a length larger than T reverses
    all T steps.

    Arguments:
        inputs (Variable): padded batch of variable length sequences.
        lengths (list[int]): list of sequence lengths, one per batch entry.
        batch_first (bool, optional): if True, inputs should be B x T x *.

    Returns:
        A Variable with the same size as inputs, but with each sequence
        reversed according to its length.

    Raises:
        ValueError: if the batch dimension of ``inputs`` does not equal
            ``len(lengths)``.
    """
    if not batch_first:
        inputs = inputs.transpose(0, 1)
    batch_size, max_len = inputs.size(0), inputs.size(1)
    if batch_size != len(lengths):
        raise ValueError('inputs incompatible with lengths.')
    # Per-sequence time index: identity everywhere, except the valid prefix
    # [0, length) which is replaced by its reverse.
    reversed_indices = [list(range(max_len)) for _ in range(batch_size)]
    for i, length in enumerate(lengths):
        if length > 0:
            reversed_indices[i][:length] = reversed_indices[i][length - 1::-1]
    # Shape the index as B x T x 1 x ... x 1 (one singleton per trailing dim)
    # so expand_as broadcasts it over any number of trailing dims, including
    # none. unsqueeze(2) alone only worked for exactly 3-D inputs.
    index_shape = (batch_size, max_len) + (1,) * (inputs.dim() - 2)
    reversed_indices = (torch.LongTensor(reversed_indices).view(*index_shape)
                        .expand_as(inputs))
    reversed_indices = Variable(reversed_indices)
    if inputs.is_cuda:
        # gather requires the index to live on the same device as inputs.
        reversed_indices = reversed_indices.cuda(inputs.get_device())
    reversed_inputs = torch.gather(inputs, 1, reversed_indices)
    if not batch_first:
        reversed_inputs = reversed_inputs.transpose(0, 1)
    return reversed_inputs
# Page-navigation residue from the source web page: "Comment list", "Table of contents".