def test_gamma_shape(self):
    """Check Gamma sample shapes, then compare log_prob against SciPy.

    Sample shapes are checked for three parameterizations: (2, 3) tensor
    parameters, one-element tensor parameters, and plain float parameters.
    The expected sizes pin this codebase's sampling API; for example, float
    parameters yield size (1,) for both ``sample()`` and ``sample_n(1)``.
    The (2, 3) distribution's ``log_prob`` is then compared elementwise with
    ``scipy.stats.gamma.logpdf`` using ``scale = 1 / beta``, i.e. ``beta``
    plays the role of a rate parameter.
    """

    def positive_leaf(*size):
        # exp of a standard-normal draw is strictly positive, which is what
        # Gamma's concentration and rate parameters require.
        return Variable(torch.randn(*size).exp(), requires_grad=True)

    # Parameters are drawn in the same order as before so the RNG stream
    # consumed by this test is unchanged.
    alpha = positive_leaf(2, 3)
    beta = positive_leaf(2, 3)
    alpha_1d = positive_leaf(1)
    beta_1d = positive_leaf(1)

    # Each row: (Gamma constructor args, sample_n count or None for
    # sample(), expected sample size).
    shape_cases = (
        ((alpha, beta), None, (2, 3)),
        ((alpha, beta), 5, (5, 2, 3)),
        ((alpha_1d, beta_1d), 1, (1, 1)),
        ((alpha_1d, beta_1d), None, (1,)),
        ((0.5, 0.5), None, (1,)),
        ((0.5, 0.5), 1, (1,)),
    )
    for params, n, expected_size in shape_cases:
        # A fresh distribution per case, exactly as constructed before.
        dist = Gamma(*params)
        drawn = dist.sample() if n is None else dist.sample_n(n)
        self.assertEqual(drawn.size(), expected_size)

    def compare_with_scipy(idx, x, log_prob):
        # idx indexes the flattened parameter tensors. Presumably
        # _check_log_prob passes the flat index of sample x together with
        # Gamma's log_prob at x -- TODO confirm against that helper.
        flat_alpha = alpha.data.view(-1)
        flat_beta = beta.data.view(-1)
        scipy_log_prob = scipy.stats.gamma.logpdf(
            x, flat_alpha[idx], scale=1 / flat_beta[idx])
        self.assertAlmostEqual(log_prob, scipy_log_prob, places=3)

    self._check_log_prob(Gamma(alpha, beta), compare_with_scipy)
# This is a randomized test.
评论列表
文章目录