test_autograd.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_stochastic(self):
        """Check that backward through stochastic samples needs rewards.

        Several stochastic samples (multinomial, bernoulli, normal) are drawn
        from values derived from a leaf that requires grad and are combined
        into one result. Calling ``backward`` on that result is expected to
        raise ``RuntimeError`` as long as any of those samples has not yet
        been given a reward through ``reinforce``. Once all of them have one,
        backward must succeed and leave a non-zero gradient on the leaf.
        """
        leaf = Variable(torch.rand(2, 10), requires_grad=True)
        stddevs = Variable(torch.rand(2, 10) * 5, requires_grad=True)

        # Clamp into [0, 1] and divide each row by its sum along dim 1.
        clamped = (leaf * 2).clamp(0, 1)
        row_totals = clamped.sum(1, True)
        probs = clamped / row_totals.expand_as(clamped)

        # Draw the samples in a fixed order so the RNG is consumed the same
        # way on every run of this test body.
        multi_draws = probs.multinomial(5)
        flat_draws = probs[0].multinomial(5)
        bernoulli_draws = probs.bernoulli()
        normal_draws = torch.normal(probs)
        normal_std_draws = torch.normal(probs, stddevs)

        # Mix every sample into a single (2, 10) double-valued result.
        combined = multi_draws * 2 + 4
        combined = combined + flat_draws.unsqueeze(0).expand_as(multi_draws)
        combined = torch.cat([combined, combined], 1)
        combined = combined.double()
        combined = combined + bernoulli_draws
        combined = combined + normal_draws
        combined = combined + normal_std_draws
        last_sample = torch.normal(combined, 4)
        total = last_sample + 2
        self.assertFalse(total.requires_grad)

        # Each entry pairs a sample with the shape of the reward it receives.
        # Before every reward is supplied, backward must still be rejected.
        pending_rewards = (
            (multi_draws, (2, 5)),
            (flat_draws, (5,)),
            (bernoulli_draws, (2, 10)),
            (normal_draws, (2, 10)),
            (normal_std_draws, (2, 10)),
        )
        for draws, reward_shape in pending_rewards:
            with self.assertRaises(RuntimeError):
                total.backward(retain_graph=True)
            draws.reinforce(torch.randn(*reward_shape))
        # We don't have to specify rewards w.r.t. last_sample - it doesn't
        # require gradient

        last_sample.backward(retain_graph=True)
        total.backward()

        leaf_grad_magnitude = leaf.grad.data.abs().sum()
        self.assertGreater(leaf_grad_magnitude, 0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号