test_nn.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_batchnorm_raises_error_if_weight_is_not_same_size_as_input(self):
        input = Variable(torch.rand(2, 10))
        running_mean = torch.rand(10)
        running_var = torch.rand(10)
        wrong_sizes = [9, 11]
        for size in wrong_sizes:
            with self.assertRaises(RuntimeError):
                F.batch_norm(input, running_mean, running_var, weight=Parameter(torch.rand(size)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号