def test_batchnorm_eval(self):
    """Check that an eval-mode BatchNorm1d is deterministic and does not update its statistics.

    For each available float tensor type, the same input is passed through
    an eval-mode ``BatchNorm1d(3)`` twice. Both passes must produce identical
    outputs and identical input gradients. The module's running statistics
    must also be unchanged by either pass.
    """
    types = (torch.FloatTensor,)
    if TEST_CUDA:
        types += (torch.cuda.FloatTensor,)
    for tp in types:
        module = nn.BatchNorm1d(3).type(tp)
        module.eval()
        # Snapshot the running statistics. Eval mode must use them without
        # updating them. Without this check the test would also pass in
        # training mode: the batch statistics of identical inputs are identical
        # on both passes, so outputs and gradients would match anyway.
        running_mean = module.running_mean.clone()
        running_var = module.running_var.clone()
        data = Variable(torch.rand(4, 3).type(tp), requires_grad=True)
        grad = torch.rand(4, 3).type(tp)
        # 1st pass
        res1 = module(data)
        res1.backward(grad)
        grad1 = data.grad.data.clone()
        # 2nd pass. Gradients accumulate across backward() calls, so reset
        # them to ensure grad2 reflects only this pass. data.grad is known to
        # be populated here (it was read just above).
        data.grad.data.zero_()
        res2 = module(data)
        res2.backward(grad)
        grad2 = data.grad.data.clone()
        self.assertEqual(res1, res2)
        self.assertEqual(grad1, grad2)
        self.assertEqual(module.running_mean, running_mean)
        self.assertEqual(module.running_var, running_var)
评论列表
文章目录