mvn_test.py — file source code

python
Views: 31 · Favorites: 0 · Likes: 0 · Comments: 0

Project: DeepLearning_VirtualReality_BigData_Project · Author: rashmitripathi · Project source · File source
def testSampleWithSampleShape(self):
    """Draw samples with a multi-dimensional sample shape and validate them.

    Builds a MultivariateNormalCholesky with batch shape (3, 5) and event
    size 2, then draws samples with sample shape (10, 11, 12) using seed 137.
    It checks three things:

    * the sample tensor has shape sample_shape + batch_shape + event_shape;
    * the empirical mean of the draws for batch element [1, 1] is within
      1e-2 of the corresponding location vector;
    * ``log_prob`` evaluated on those draws agrees with scipy's
      ``multivariate_normal.logpdf`` for the same batch element.
    """
    with self.test_session():
        # The location must be drawn before the Cholesky factor so the RNG
        # stream produces the same values as the original test.
        loc = self._rng.rand(3, 5, 2)
        scale_tril, covariance = self._random_chol(3, 5, 2, 2)

        dist = distributions.MultivariateNormalCholesky(loc, scale_tril)
        sample_shape = (10, 11, 12)
        draws = dist.sample(sample_shape, seed=137).eval()

        # Sample dims come first, followed by batch dims (3, 5) and event dim 2.
        self.assertEqual(sample_shape + (3, 5, 2), draws.shape)

        # Empirical mean for a single batch member, flattened over all
        # 10 * 11 * 12 sample positions.
        member_draws = draws[..., 1, 1, :]
        flat_draws = member_draws.reshape(-1, 2)
        self.assertAllClose(flat_draws.mean(axis=0), loc[1, 1], atol=1e-2)

        # log_prob drops the event dim, so its result has shape
        # sample_shape + (3, 5); select the same batch member.
        member_log_prob = dist.log_prob(draws).eval()[..., 1, 1]
        reference = stats.multivariate_normal(
            mean=loc[1, 1], cov=covariance[1, 1])
        self.assertAllClose(reference.logpdf(member_draws), member_log_prob)
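Comment list
Table of contents


Questions


Interview experiences


Articles


WeChat
Official account

Scan the QR code to follow the official account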