def test_flatten_and_batch_shift_indices(self):
    """Check that batch-shifted indices are flattened into a single dimension.

    The input has two batch elements, each a 3x4 grid of indices, and the
    shift argument is 10. The expected values show that the first batch
    element's indices come back unchanged and every index of the second
    batch element is increased by 10. All 24 values are returned in
    row-major order as one flat array.
    """
    shift = 10
    raw = numpy.array([
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 9, 9, 9]],
        [[2, 1, 0, 7], [7, 7, 2, 3], [0, 0, 4, 2]],
    ])
    # torch.LongTensor pins the dtype to int64 regardless of numpy's
    # platform-dependent default integer type.
    index_var = Variable(torch.LongTensor(raw))
    flat = util.flatten_and_batch_shift_indices(index_var, shift)
    expected = numpy.array(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]              # batch 0: no shift
        + [12, 11, 10, 17, 17, 17, 12, 13, 10, 10, 14, 12]  # batch 1: +10
    )
    numpy.testing.assert_array_equal(flat.data.numpy(), expected)
评论列表
文章目录