def test_normal(self):
    """Exercise the Normal distribution.

    Covers sample()/sample_n() output shapes for tensor and Python-scalar
    parameters, sampling with a vanishingly small std, the
    ``_gradcheck_log_prob`` helper on tensor/scalar parameter mixes,
    gradients of a reparameterized sample (rsample), and log_prob against
    the closed-form Gaussian log-density.
    """
    mean = Variable(torch.randn(5, 5), requires_grad=True)
    std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
    mean_1d = Variable(torch.randn(1), requires_grad=True)
    # A Normal scale must be non-negative: take abs() exactly like the 5x5
    # ``std`` above instead of passing a raw (possibly negative) randn draw.
    std_1d = Variable(torch.randn(1).abs(), requires_grad=True)
    mean_delta = torch.Tensor([1.0, 0.0])
    std_delta = torch.Tensor([1e-5, 1e-5])

    # Output shapes: sample() returns the parameter shape, sample_n(n)
    # prepends a leading dimension of size n.
    self.assertEqual(Normal(mean, std).sample().size(), (5, 5))
    self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5))
    self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
    self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1,))
    self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1,))
    self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1,))

    # Extreme parameters: with std = 1e-5 every draw must match the mean
    # to within prec=1e-4; the result shape is sample_shape + (2,).
    self._set_rng_seed(1)
    self.assertEqual(Normal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
                     torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                     prec=1e-4)

    # Run the log_prob gradient-check helper on tensor/tensor,
    # tensor/scalar and scalar/tensor parameter combinations.
    self._gradcheck_log_prob(Normal, (mean, std))
    self._gradcheck_log_prob(Normal, (mean, 1.0))
    self._gradcheck_log_prob(Normal, (0.0, std))

    # Reparameterization: replay the RNG state so ``eps`` reproduces the
    # standard-normal noise rsample() is expected to draw; the gradients of
    # z w.r.t. mean and std must then be 1 and eps respectively.
    state = torch.get_rng_state()
    eps = torch.normal(torch.zeros_like(mean), torch.ones_like(std))
    torch.set_rng_state(state)
    z = Normal(mean, std).rsample()
    z.backward(torch.ones_like(z))
    self.assertEqual(mean.grad, torch.ones_like(mean))
    self.assertEqual(std.grad, eps)
    # Clear the gradients accumulated by backward() above.
    mean.grad.zero_()
    std.grad.zero_()
    self.assertEqual(z.size(), (5, 5))

    def ref_log_prob(idx, x, log_prob):
        """Compare ``log_prob`` at flat index ``idx`` to the analytic value."""
        m = mean.data.view(-1)[idx]
        s = std.data.view(-1)[idx]
        # Evaluate the Gaussian log-density directly in log space:
        # math.exp() of a large negative exponent underflows to 0.0 and
        # math.log(0.0) raises ValueError, whereas this form stays finite.
        # float() keeps the comparison on a plain Python number.
        expected = float(-(x - m) ** 2 / (2 * s ** 2) -
                         0.5 * math.log(2 * math.pi * s ** 2))
        self.assertAlmostEqual(log_prob, expected, places=3)

    self._check_log_prob(Normal(mean, std), ref_log_prob)
# This is a randomized test.
评论列表
文章目录