gamma_test.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testGammaGammaKL(self):
    """Checks Gamma-Gamma KL against a closed form and a Monte Carlo estimate.

    Two batches of Gamma distributions are built, and `kullback_leibler.kl` is
    compared against two references:

    * the closed form below, with a tight (1e-6) relative tolerance;
    * a sample-based estimate, with a loose (1e-2) relative tolerance.

    The closed form is:

      KL = (alpha0 - alpha1) * digamma(alpha0)
           + gammaln(alpha1) - gammaln(alpha0)
           + alpha1 * (log(beta0) - log(beta1))
           + alpha0 * (beta1 / beta0 - 1)

    This formula corresponds to treating `beta` as a rate parameter.
    NOTE(review): confirm that this matches the parameterization of
    `gamma_lib.Gamma`.

    The single-element alphas broadcast against the six-element betas, so
    the KL result is expected to have `beta0`'s shape.
    """
    alpha0 = np.array([3.])
    beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])

    alpha1 = np.array([0.4])
    beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

    with self.test_session() as sess:
      # Build graph.
      g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0)
      g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1)
      # Monte Carlo estimate of KL(g0 || g1) = E_{x ~ g0}[log g0(x) - log g1(x)].
      # The mean is taken over axis 0, the sample axis. A fixed op-level seed
      # is used so the estimate is reproducible.
      x = g0.sample(int(1e4), seed=0)
      kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
      kl_actual = kullback_leibler.kl(g0, g1)

      # Execute graph while the session (and its graph) is still the active
      # default. Do not rely on the session outliving the context manager.
      kl_sample_, kl_actual_ = sess.run([kl_sample, kl_actual])

    kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
                   + special.gammaln(alpha1)
                   - special.gammaln(alpha0)
                   + alpha1 * np.log(beta0)
                   - alpha1 * np.log(beta1)
                   + alpha0 * (beta1 / beta0 - 1.))

    # Static shape must reflect broadcasting of the alphas against the betas.
    self.assertEqual(beta0.shape, kl_actual.get_shape())
    # Analytic KL must match the closed form almost exactly.
    self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
    # The sampling estimate uses only 1e4 draws, so a looser 1% relative
    # tolerance is used.
    self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号