def test_batchnorm_raises_error_if_weight_is_not_same_size_as_input(self):
input = Variable(torch.rand(2, 10))
running_mean = torch.rand(10)
running_var = torch.rand(10)
wrong_sizes = [9, 11]
for size in wrong_sizes:
with self.assertRaises(RuntimeError):
F.batch_norm(input, running_mean, running_var, weight=Parameter(torch.rand(size)))
评论列表
文章目录