test_nn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_AlphaDropout(self):
        # generate random tensor with zero mean and unit std
        input = torch.randn(5000)

        mean = input.mean()
        std = input.std()

        for p in [0.2, 0.5, 0.8]:
            module = nn.AlphaDropout(p)
            input_var = Variable(input, requires_grad=True)
            output = module(input_var)
            # output mean should be close to input mean
            self.assertLess(abs(output.data.mean() - mean), 0.1)
            # output std should be close to input std
            self.assertLess(abs(output.data.std() - std), 0.1)
            output.backward(input)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号