def testSampleWithSampleShape(self):
    """Sample with a multi-dimensional sample shape and validate the draws.

    Builds a MultivariateNormalCholesky whose location has shape (3, 5, 2)
    and draws a (10, 11, 12) block of samples from it. The test then checks:
      * the shape of the returned samples,
      * the empirical mean of batch member [1, 1] against its location,
      * that `log_prob` accepts the full sample tensor and, for member
        [1, 1], agrees with scipy's reference multivariate-normal log-density.
    """
    sample_shape = (10, 11, 12)
    with self.test_session():
        loc = self._rng.rand(3, 5, 2)
        # `covariance` is the matrix handed to scipy as `cov=` below.
        scale_tril, covariance = self._random_chol(3, 5, 2, 2)
        dist = distributions.MultivariateNormalCholesky(loc, scale_tril)
        draws = dist.sample(sample_shape, seed=137).eval()

        # Expected layout: sample dims, then the (3, 5) batch, then event dim 2.
        self.assertEqual(sample_shape + (3, 5, 2), draws.shape)

        # Isolate batch member [1, 1] and flatten all sample dims so the mean
        # is taken over every draw of that member.
        member_draws = draws[:, :, :, 1, 1, :]
        flat_member = member_draws.reshape(10 * 11 * 12, 2)
        self.assertAllClose(flat_member.mean(axis=0), loc[1, 1], atol=1e-2)

        # log_prob must work on the whole sample tensor; compare member [1, 1]
        # against scipy, which evaluates logpdf over the leading sample dims.
        log_density = dist.log_prob(draws).eval()
        member_log_density = log_density[:, :, :, 1, 1]
        reference = stats.multivariate_normal(
            mean=loc[1, 1, :], cov=covariance[1, 1, :, :]).logpdf(member_draws)
        self.assertAllClose(reference, member_log_density)
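# Source file: mvn_test.py (python).
# Scraped page metadata: reads: 31, favorites: 0, likes: 0, comments: 0.
# Page sections: comment list; table of contents.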