def test_sort_tensor_by_length(self):
    """Exercise ``util.sort_batch_by_length`` on a zero-padded random batch.

    A ``(5, 7, 9)`` random tensor is built and the trailing part of its
    first four rows is zeroed to act as padding, matching the sequence
    lengths ``[3, 4, 1, 5, 7]``.  The test then asserts that:

    * rows 1-4 of the sorted tensor are zero from the padding offset
      expected for lengths ``5, 4, 3, 1`` onwards,
    * the returned lengths are ``[7, 5, 4, 3, 1]``,
    * selecting rows of the sorted tensor with the returned restoration
      indices reproduces the original tensor exactly.
    """
    # Row index -> first position along dim 1 that is zeroed out as padding.
    # Row 4 is left fully populated (its length is 7).
    padding_start_by_row = ((0, 3), (1, 4), (2, 1), (3, 5))

    tensor = torch.rand([5, 7, 9])
    for row, padding_start in padding_start_by_row:
        tensor[row, padding_start:, :] = 0
    tensor = Variable(tensor)
    sequence_lengths = Variable(torch.LongTensor([3, 4, 1, 5, 7]))

    # The fourth returned value is not examined by this test.
    sorted_tensor, sorted_lengths, reverse_indices, _ = util.sort_batch_by_length(
        tensor, sequence_lengths
    )

    # Test sorted indices are padded correctly: each listed row of the
    # sorted tensor must be all zeros from its padding offset onwards.
    expected_padding_in_sorted = ((1, 5), (2, 4), (3, 3), (4, 1))
    for row, padding_start in expected_padding_in_sorted:
        padded_region = sorted_tensor[row, padding_start:, :].data.numpy()
        numpy.testing.assert_array_equal(padded_region, 0.0)

    expected_lengths = torch.LongTensor([7, 5, 4, 3, 1])
    assert sorted_lengths.data.equal(expected_lengths)

    # Test restoration indices correctly recover the original tensor.
    restored = sorted_tensor.index_select(0, reverse_indices)
    assert restored.data.equal(tensor.data)
# Comment list
# Table of contents