mvn_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testLogPDFXIsHigherRank(self):
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      mvn = distributions.MultivariateNormalCholesky(mu, chol)
      x = self._rng.rand(3, 2)

      log_pdf = mvn.log_prob(x)
      pdf = mvn.prob(x)

      scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)

      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual((3,), log_pdf.get_shape())
      self.assertEqual((3,), pdf.get_shape())
      self.assertAllClose(expected_log_pdf, log_pdf.eval())
      self.assertAllClose(expected_pdf, pdf.eval())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号