def test_importance_prior(self):
    """Check that guide-less importance sampling recovers the expected posterior.

    Runs ``pyro.infer.Importance`` on ``self.model`` with ``guide=None`` and
    10000 importance samples (per the test name, the proposal is the prior).
    It wraps the result in ``pyro.infer.Marginal`` and draws 1000 values from
    that marginal. It then asserts that:

    * their empirical mean is within 0.01 of ``self.mu_mean``;
    * their empirical standard deviation is within 0.1 of ``self.mu_stddev``.
    """
    posterior = pyro.infer.Importance(self.model, guide=None, num_samples=10000)
    marginal = pyro.infer.Marginal(posterior)
    # Concatenate the draws once and reuse the result for both statistics.
    # NOTE(review): assumes each marginal() draw is a 1-D tensor, so the
    # concatenation is a flat vector of samples. Confirm against self.model.
    samples = torch.cat([marginal() for _ in range(1000)])
    posterior_mean = torch.mean(samples)
    posterior_stddev = torch.std(samples, 0)
    # torch.norm of the difference is the absolute error for a single value.
    self.assertEqual(0, torch.norm(posterior_mean - self.mu_mean).data[0],
                     prec=0.01)
    self.assertEqual(0, torch.norm(posterior_stddev - self.mu_stddev).data[0],
                     prec=0.1)
评论列表
文章目录