def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):
    """
    Sort a batch first tensor by some specified lengths.

    Parameters
    ----------
    tensor : Variable(torch.FloatTensor), required.
        A batch first Pytorch tensor.
    sequence_lengths : Variable(torch.LongTensor), required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : Variable(torch.FloatTensor)
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : Variable(torch.LongTensor)
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : Variable(torch.LongTensor)
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    permutation_index : Variable(torch.LongTensor)
        The indices used to sort the tensor. This is useful if you want to sort many
        tensors using the same ordering.

    Raises
    ------
    ConfigurationError
        If either argument is not a ``torch.autograd.Variable``.
    """
    if not isinstance(tensor, Variable) or not isinstance(sequence_lengths, Variable):
        raise ConfigurationError("Both the tensor and sequence lengths must be torch.autograd.Variables.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)
    # Sorting the permutation index itself tells us, for each original batch
    # position, where that row ended up in the sorted tensor - exactly the
    # indices required to undo the sort. These indices come from ``sort`` and
    # therefore already have the correct CUDA vs non-CUDA type, so no new
    # tensor has to be created and re-filled at runtime.
    _, restoration_indices = permutation_index.sort(0, descending=False)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index