def test_scatterFill(self):
    """Check in-place ``scatter_`` with a scalar fill value on a 3-D tensor.

    A zero tensor of random shape is scatter-filled along a randomly
    chosen dimension, and the result is compared against a reference
    built one element at a time with plain Python loops.  Afterwards a
    single index entry is overwritten with an out-of-range value and
    ``scatter_`` is expected to raise ``RuntimeError``.
    """
    # Random draws are kept in this exact order so that a seeded run
    # produces the same values as before.
    m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
    elems_per_row = random.randint(1, 10)
    dim = random.randrange(3)
    val = random.random()
    shape = (m, n, o)

    # The index tensor has the target's shape, except that its extent
    # along `dim` is `elems_per_row`.
    idx_size = list(shape)
    idx_size[dim] = elems_per_row
    idx = torch.LongTensor().resize_(*idx_size)
    # NOTE(review): `_fill_indices` is defined elsewhere; presumably it
    # populates `idx` with valid positions along `dim` -- confirm there.
    self._fill_indices(idx, dim, shape[dim], elems_per_row, m, n, o)

    actual = torch.zeros(*shape).scatter_(dim, idx, val)

    # Reference result: for every entry of `idx`, write `val` at the
    # position whose `dim` coordinate is replaced by that entry.
    expected = torch.zeros(*shape)
    extent0, extent1, extent2 = idx_size
    for a in range(extent0):
        for b in range(extent1):
            for c in range(extent2):
                source = (a, b, c)
                target = source[:dim] + (idx[source],) + source[dim + 1:]
                expected[target] = val
    # NOTE(review): the trailing 0 is passed positionally; presumably a
    # tolerance in this TestCase's assertEqual -- confirm against the base class.
    self.assertEqual(actual, expected, 0)

    # 28 exceeds every extent (each is at most 20), so this entry is
    # out of range whichever dimension was chosen.
    idx[0][0][0] = 28

    def scatter_with_bad_index():
        torch.zeros(*shape).scatter_(dim, idx, val)

    self.assertRaises(RuntimeError, scatter_with_bad_index)
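# 评论列表 (comment list) -- web page navigation residue, not part of the code
# 文章目录 (article table of contents) -- web page navigation residue, not part of the code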