def test_bucket_values(self):
    """Check ``util.bucket_values`` against hand-computed bucket ids.

    The expected array pins the bucket assigned to each input value:
    1 and 2 map to themselves, while 7, 56 and 900 map to 5, 8 and 9.
    NOTE(review): this pattern looks like identity buckets for small
    values plus log-scaled, capped buckets for larger ones -- confirm
    against the implementation of ``util.bucket_values``.
    """
    distances = torch.tensor([1, 2, 7, 1, 56, 900], dtype=torch.long)
    expected_buckets = numpy.array([1, 2, 5, 1, 8, 9])

    # Convert to numpy so the element-wise comparison reports every mismatch.
    actual_buckets = util.bucket_values(distances).numpy()
    numpy.testing.assert_array_equal(actual_buckets, expected_buckets)
评论列表
文章目录