def test_MaskedSelect(self):
    """Exercise the legacy ``nn.MaskedSelect`` module.

    Three checks are performed:

    1. ``forward`` on a random 4x5 input with a random 0/1 byte mask must
       agree with ``Tensor.masked_select`` applied to the same operands.
    2. ``backward`` on a fixed 2x2 example must route the incoming gradient
       into the positions the mask selects, with zeros everywhere else.
    3. Calling ``__repr__`` and ``str`` on the module must not raise.
    """
    # --- forward pass agrees with Tensor.masked_select ---------------------
    random_src = torch.randn(4, 5)
    random_mask = torch.ByteTensor(4, 5).bernoulli_()
    selector = nn.MaskedSelect()
    selected = selector.forward([random_src, random_mask])
    self.assertEqual(random_src.masked_select(random_mask), selected)

    # --- backward pass routes gradients into the masked slots --------------
    # The mask keeps only the diagonal, so the two incoming gradient values
    # are expected at (0, 0) and (1, 1) of the input gradient.
    src = torch.Tensor(((10, 20), (30, 40)))
    mask = torch.ByteTensor(((1, 0), (0, 1)))
    expected_grad_input = torch.Tensor(((20, 0), (0, 80)))
    grad_output = torch.Tensor((20, 80))

    selector = nn.MaskedSelect()
    selector.forward([src, mask])
    grad_input = selector.backward([src, mask], grad_output)
    # backward returns a sequence; element 0 is the gradient for `src`.
    self.assertEqual(expected_grad_input, grad_input[0])

    # --- string conversions must not raise ---------------------------------
    selector.__repr__()
    str(selector)
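# Comment list  (navigation label left over from the web page this was copied from)
# Table of contents  (navigation label left over from the web page this was copied from)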