dirichlet_test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testCovarianceFromSampling(self):
    alpha = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    with self.test_session() as sess:
      dist = dirichlet_lib.Dirichlet(alpha)  # batch_shape=[2], event_shape=[3]
      x = dist.sample(int(250e3), seed=1)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean[None, ...]
      sample_cov = math_ops.reduce_mean(math_ops.matmul(
          x_centered[..., None], x_centered[..., None, :]), 0)
      sample_var = array_ops.matrix_diag_part(sample_cov)
      sample_stddev = math_ops.sqrt(sample_var)
      [
          sample_mean_,
          sample_cov_,
          sample_var_,
          sample_stddev_,
          analytic_mean,
          analytic_cov,
          analytic_var,
          analytic_stddev,
      ] = sess.run([
          sample_mean,
          sample_cov,
          sample_var,
          sample_stddev,
          dist.mean(),
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
      ])
      self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
      self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06)
      self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号