def test_VolumetricDropout(self):
    """Exercise nn.VolumetricDropout in training mode.

    Feeds an all-ones 5-D tensor of shape (bsz, nfeats, t, w, h) through
    the module and checks that:

    * forward keeps roughly a ``1 - p`` fraction of the values (the mean of
      the output is within ``tolerance`` of ``1 - p``);
    * the dropout mask is shared across each (sample, feature) volume, i.e.
      every element of a given feature map is kept or dropped together;
    * backward, using the all-ones tensor as gradOutput, yields a gradInput
      whose mean is likewise within ``tolerance`` of ``1 - p``;
    * ``repr`` / ``str`` of the module do not raise.
    """
    p = 0.2
    expected_mean = 1 - p
    # nfeats is large so the dropout draws average out. If each of the
    # >= bsz * nfeats feature maps is dropped independently with probability
    # p, the std. dev. of the mean is at most sqrt(p * (1 - p) / 1000) ~= 0.013,
    # so 0.05 is roughly a 4-sigma bound.
    tolerance = 0.05
    bsz = random.randint(1, 5)
    t = random.randint(1, 5)
    w = random.randint(1, 5)
    h = random.randint(1, 5)
    nfeats = 1000
    # Named `ones` rather than `input` to avoid shadowing the builtin.
    ones = torch.Tensor(bsz, nfeats, t, w, h).fill_(1)

    module = nn.VolumetricDropout(p)
    # Enable training mode so the dropout mask is actually applied.
    module.training()

    output = module.forward(ones)
    self.assertLess(abs(output.mean() - expected_mean), tolerance)

    # Volumetric dropout drops whole feature maps, so within every
    # (sample, feature) volume all elements must be equal (all kept or all
    # zeroed): the max-minus-min spread of each map has to be zero.
    per_map = output.contiguous().view(bsz, nfeats, -1)
    spread = per_map.max(2)[0] - per_map.min(2)[0]
    self.assertEqual(spread.max(), 0)

    # The all-ones tensor doubles as gradOutput, so gradInput should reflect
    # the same keep fraction as the forward pass.
    gradInput = module.backward(ones, ones)
    self.assertLess(abs(gradInput.mean() - expected_mean), tolerance)

    # Check that these don't raise errors
    module.__repr__()
    str(module)
# Stray page-navigation labels from the scraped source page: "Comment list", "Table of contents".