import torch
from torch.autograd import Variable

# Legacy test of PyTorch's stochastic functions (pre-0.4 API, when samples
# carried .reinforce()); written as a unittest.TestCase method, as in
# PyTorch's own test_autograd.py.
def test_stochastic(self):
x = Variable(torch.rand(2, 10), requires_grad=True)
stddevs = Variable(torch.rand(2, 10) * 5, requires_grad=True)
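# Build a row-stochastic matrix: scale, clamp to [0, 1], then normalize
# each row so it sums to 1 and can serve as a probability distribution.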
y = (x * 2).clamp(0, 1)
y = y / y.sum(1, keepdim=True).expand_as(y)
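# Draw samples through the pre-0.4 stochastic-function API. Each call below
# adds a stochastic node to the graph: multinomial draws 5 category indices
# per row, bernoulli draws 0/1 per element, and torch.normal treats y as
# per-element means (with stddevs as per-element standard deviations in the
# last call).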
samples_multi = y.multinomial(5)
samples_multi_flat = y[0].multinomial(5)
samples_bernoulli = y.bernoulli()
samples_norm = torch.normal(y)
samples_norm_std = torch.normal(y, stddevs)
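# Fold every sample into a single downstream value, so one backward() call
# has to traverse all of the stochastic nodes at once.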
z = samples_multi * 2 + 4
z = z + samples_multi_flat.unsqueeze(0).expand_as(samples_multi)
z = torch.cat([z, z], 1)
z = z.double()
z = z + samples_bernoulli + samples_norm + samples_norm_std
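# One more stochastic node: sample with per-element mean z and stddev 4.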
last_sample = torch.normal(z, 4)
z = last_sample + 2
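# Sampling is non-differentiable, so z does not require grad in the usual
# sense; gradients reach x only through REINFORCE rewards.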
self.assertFalse(z.requires_grad)
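# backward() must keep raising until every stochastic node has been given a
# reward via .reinforce(); each reward tensor matches its sample's shape.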
self.assertRaises(RuntimeError, lambda: z.backward(retain_graph=True))
samples_multi.reinforce(torch.randn(2, 5))
self.assertRaises(RuntimeError, lambda: z.backward(retain_graph=True))
samples_multi_flat.reinforce(torch.randn(5))
self.assertRaises(RuntimeError, lambda: z.backward(retain_graph=True))
samples_bernoulli.reinforce(torch.randn(2, 10))
self.assertRaises(RuntimeError, lambda: z.backward(retain_graph=True))
samples_norm.reinforce(torch.randn(2, 10))
self.assertRaises(RuntimeError, lambda: z.backward(retain_graph=True))
samples_norm_std.reinforce(torch.randn(2, 10))
# We don't have to specify rewards w.r.t. last_sample - it doesn't
# require gradient
last_sample.backward(retain_graph=True)
z.backward()
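# The REINFORCE gradient estimates must have propagated all the way back to
# the leaf x.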
self.assertGreater(x.grad.data.abs().sum(), 0)
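# NOTE: .reinforce() and the stochastic-function machinery were removed in
# PyTorch 0.4 in favor of torch.distributions. Below is a minimal sketch of
# the equivalent REINFORCE pattern with the modern API; the random `rewards`
# tensor and the function name are illustrative stand-ins, not part of any
# PyTorch API.
def reinforce_sketch():
    logits = torch.randn(2, 10, requires_grad=True)
    dist = torch.distributions.Categorical(logits=logits)
    samples = dist.sample((5,))     # 5 draws per batch row -> shape (5, 2)
    rewards = torch.randn(5, 2)     # stand-in reward, one per sample
    # Surrogate loss whose gradient is the REINFORCE estimator:
    # grad E[R] ~= E[R * grad log pi(sample)]
    loss = -(rewards * dist.log_prob(samples)).mean()
    loss.backward()
    assert logits.grad is not None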