def test_exponential(self):
    # Exercises Exponential in four stages: sample sizes for tensor and
    # plain-float rates, a gradient check of log_prob, the gradient that
    # rsample() sends back to the rate, and log_prob values against the
    # closed form log(rate) - rate * x.
    rate = Variable(torch.randn(5, 5).abs(), requires_grad=True)
    rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)

    # Each entry pairs a deferred draw with the size it must produce. The
    # draws are executed in list order, one per assertion.
    size_cases = [
        (lambda: Exponential(rate).sample(), (5, 5)),
        (lambda: Exponential(rate).sample((7,)), (7, 5, 5)),
        (lambda: Exponential(rate_1d).sample((1,)), (1, 1)),
        (lambda: Exponential(rate_1d).sample(), (1,)),
        (lambda: Exponential(0.2).sample((1,)), (1,)),
        (lambda: Exponential(50.0).sample((1,)), (1,)),
    ]
    for draw, expected_size in size_cases:
        self.assertEqual(draw().size(), expected_size)

    self._gradcheck_log_prob(Exponential, (rate,))

    # Snapshot the RNG, draw reference noise, then rewind so that the
    # rsample() call below starts from the very same RNG state.
    rng_snapshot = torch.get_rng_state()
    noise = rate.new(rate.size()).exponential_()
    torch.set_rng_state(rng_snapshot)

    draws = Exponential(rate).rsample()
    draws.backward(torch.ones_like(draws))
    # NOTE(review): the expected gradient presumably reflects rsample()
    # computing noise / rate -- confirm against Exponential.rsample.
    expected_grad = -noise / rate ** 2
    self.assertEqual(rate.grad, expected_grad)
    # Clear the accumulated gradient before the remaining checks.
    rate.grad.zero_()
    self.assertEqual(draws.size(), (5, 5))

    def ref_log_prob(idx, x, log_prob):
        # Reference value for one flattened rate entry, compared to three
        # decimal places. The meaning of (idx, x, log_prob) is set by
        # _check_log_prob, which invokes this callback -- TODO confirm.
        rate_at_idx = rate.data.view(-1)[idx]
        reference = math.log(rate_at_idx) - rate_at_idx * x
        self.assertAlmostEqual(log_prob, reference, places=3)

    self._check_log_prob(Exponential(rate), ref_log_prob)
# This is a randomized test.
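# NOTE: stray web-page navigation text ("comment list", "article contents") was captured here; it is not part of the test code.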