mvn_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def _testPDFShapes(self, mvn_dist, mu, sigma):
  """Checks log_prob/prob output shapes and values for a [3, 3] batch of MVNs.

  Builds `mvn_dist(mu, sigma)`, evaluates `log_prob` and `prob` at
  x = 2 * ones_like(mu), then feeds a [3, 3] batch of 2-D parameters (every
  mean is ones(2), every sigma is the 2x2 identity). The output shapes must be
  [3, 3], and every batch entry must match scipy's multivariate_normal.

  Args:
    mvn_dist: Callable taking (mu, sigma) and returning an object exposing
      `log_prob` and `prob`.
    mu: Tensor used as a feed_dict key and fed [3, 3, 2] mean values;
      presumably a placeholder -- TODO confirm at call sites.
    sigma: Tensor used as a feed_dict key and fed [3, 3, 2, 2] values. Whether
      mvn_dist reads this as a covariance or a scale factor is not visible
      here; the identity fed below makes both interpretations coincide.
  """
  with self.test_session() as sess:
    mvn = mvn_dist(mu, sigma)
    x = 2 * array_ops.ones_like(mu)

    log_pdf = mvn.log_prob(x)
    pdf = mvn.prob(x)

    batch_shape = [3, 3]
    event_dim = 2
    mu_value = np.ones(batch_shape + [event_dim])
    sigma_value = np.zeros(batch_shape + [event_dim, event_dim])
    # Broadcast the identity matrix into every batch slot.
    sigma_value[:] = np.identity(event_dim)
    # Concrete counterpart of x = 2 * ones_like(mu), used for the scipy side.
    x_value = 2. * np.ones(batch_shape + [event_dim])
    feed_dict = {mu: mu_value, sigma: sigma_value}

    # Every batch entry has identical parameters and evaluation point, so a
    # single scipy reference value, broadcast to the batch shape, applies to
    # all of them.
    scipy_mvn = stats.multivariate_normal(
        mean=mu_value[0, 0], cov=sigma_value[0, 0])
    expected_log_pdf = np.full(batch_shape, scipy_mvn.logpdf(x_value[0, 0]))
    expected_pdf = np.full(batch_shape, scipy_mvn.pdf(x_value[0, 0]))

    log_pdf_evaled, pdf_evaled = sess.run([log_pdf, pdf], feed_dict=feed_dict)
    self.assertAllEqual(batch_shape, log_pdf_evaled.shape)
    self.assertAllEqual(batch_shape, pdf_evaled.shape)
    # Compare the whole batch rather than only entry [0, 0], so a broadcasting
    # error in any other batch slot is also caught.
    self.assertAllClose(expected_log_pdf, log_pdf_evaled)
    self.assertAllClose(expected_pdf, pdf_evaled)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号