def test_masked_copy(self):
    """Check ``masked_copy_`` against an element-by-element reference copy.

    Three cases are exercised:

    * the source has exactly as many elements as the mask has ones;
    * the source has more elements than the mask has ones (the result is
      compared against the same reference, which consumes source elements
      in order, one per set mask entry);
    * the source has fewer elements than the mask has ones, which is
      expected to raise ``RuntimeError``.
    """
    num_copy, num_dest = 3, 10
    mask = torch.ByteTensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0))

    def reference_copy(target, source):
        # Build the expected result on a clone so ``target`` is untouched:
        # walk the mask and take the next unused source element for every
        # set entry.
        expected = target.clone()
        next_src = 0
        for pos in range(num_dest):
            if mask[pos]:
                expected[pos] = source[next_src]
                next_src += 1
        return expected

    # source holds exactly one element per set mask entry
    dest = torch.randn(num_dest)
    src = torch.randn(num_copy)
    expected = reference_copy(dest, src)
    dest.masked_copy_(mask, src)
    # NOTE(review): the trailing 0 is kept from the original call; it is
    # presumably the allowed tolerance of the suite's assertEqual - confirm.
    self.assertEqual(dest, expected, 0)

    # make source bigger than number of 1s in mask
    src = torch.randn(num_dest)
    # The reference must be computed before the in-place call mutates dest.
    expected = reference_copy(dest, src)
    dest.masked_copy_(mask, src)
    self.assertEqual(dest, expected, 0)

    # make src smaller. this should fail
    src = torch.randn(num_copy - 1)
    with self.assertRaises(RuntimeError):
        dest.masked_copy_(mask, src)
评论列表
文章目录