def test_batched_index_select(self):
    """Check that ``util.batched_index_select`` gathers the indexed rows per batch.

    The target tensor has shape (2, 10, 3). In batch 0, every entry of row ``i``
    equals ``i``. In batch 1, every entry of row ``i`` equals ``2 * i``. So the
    value of each selected vector shows which row it came from and from which
    batch.
    """
    raw_indices = numpy.array([[[1, 2],
                                [3, 4]],
                               [[5, 6],
                                [7, 8]]])
    # cumsum of ones along dim 1 gives 1..10; subtracting 1 makes each row's
    # entries equal to that row's own index (0..9).
    target_tensor = torch.ones([2, 10, 3]).cumsum(1) - 1
    # Double the second batch so values from the two batches never coincide.
    target_tensor[1, :, :] *= 2

    index_var = Variable(torch.LongTensor(raw_indices))
    target_var = Variable(target_tensor)
    selected = util.batched_index_select(target_var, index_var)
    assert list(selected.size()) == [2, 2, 2, 3]

    # Expected scalar for each (batch, row, col) position: the selected row
    # index, doubled in batch 1.
    expected_values = [[[1, 2],
                        [3, 4]],
                       [[10, 12],
                        [14, 16]]]
    ones = numpy.ones([3])
    # ndindex walks positions in C order, matching a nested batch/row/col loop.
    for batch, row, col in numpy.ndindex(2, 2, 2):
        numpy.testing.assert_array_equal(
            selected[batch, row, col, :].data.numpy(),
            ones * expected_values[batch][row][col])
评论列表
文章目录