def testLogPDFXIsHigherRank(self):
    """Compare log_prob/prob with scipy when x has an extra batch dimension.

    A 2-dimensional MultivariateNormalCholesky is evaluated at a (3, 2) array
    of points. Both outputs must have static shape (3,), one value per row,
    and must match scipy.stats.multivariate_normal elementwise.
    """
    with self.test_session() as sess:
        # Keep this draw order: _random_chol presumably also consumes
        # self._rng, so reordering would change the generated test data.
        loc = self._rng.rand(2)
        chol, cov = self._random_chol(2, 2)
        dist = distributions.MultivariateNormalCholesky(loc, chol)
        points = self._rng.rand(3, 2)

        log_prob_t = dist.log_prob(points)
        prob_t = dist.prob(points)

        # Reference densities computed with scipy from the same mean/covariance.
        reference = stats.multivariate_normal(mean=loc, cov=cov)
        want_log_prob = reference.logpdf(points)
        want_prob = reference.pdf(points)

        # The trailing event dimension is reduced away, leaving one
        # density value per row of `points`.
        for tensor in (log_prob_t, prob_t):
            self.assertEqual((3,), tensor.get_shape())

        got_log_prob, got_prob = sess.run([log_prob_t, prob_t])
        self.assertAllClose(want_log_prob, got_log_prob)
        self.assertAllClose(want_prob, got_prob)
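# Source file: mvn_test.py (language: python)
# Scraped page metadata (not code): 25 reads, 0 favorites, 0 likes,
# 0 comments; "comment list"; "article table of contents".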