def test_batchnorm_raises_error_if_running_mean_is_not_same_size_as_input(self):
input = Variable(torch.rand(2, 10))
running_var = torch.rand(10)
wrong_sizes = [9, 11]
for size in wrong_sizes:
with self.assertRaises(RuntimeError):
F.batch_norm(input, torch.rand(size), running_var)
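# Comment list       (web-page navigation label, not part of the code)
# Table of contents  (web-page navigation label, not part of the code)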