beta_test.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testBetaBetaKL(self):
    """Check KL(Beta(a1, b1) || Beta(a2, b2)) against a closed-form reference.

    For each tested batch shape, random positive parameters are drawn and the
    KL divergence is computed for every pairing of a plain `Beta` with a
    `BetaWithSoftplusAB` built from inverse-softplus-transformed parameters.
    Each result must have the batch shape and match the reference value
    computed with `scipy.special`. KL(d1 || d1) is also checked to be zero.
    """
    def inverse_softplus(x):
      # log(exp(x) - 1). expm1 avoids the cancellation that a literal
      # `exp(x) - 1.0` suffers for small x (samples go down to ~1e-4).
      return np.log(np.expm1(x))

    with self.test_session() as sess:
      for shape in [(10,), (4, 5)]:
        # Strictly positive parameters in [1e-4, 6 + 1e-4).
        a1 = 6.0 * np.random.random(size=shape) + 1e-4
        b1 = 6.0 * np.random.random(size=shape) + 1e-4
        a2 = 6.0 * np.random.random(size=shape) + 1e-4
        b2 = 6.0 * np.random.random(size=shape) + 1e-4

        d1 = beta_lib.Beta(a=a1, b=b1)
        d2 = beta_lib.Beta(a=a2, b=b2)
        # Feed inverse-softplus values to BetaWithSoftplusAB; the assertions
        # below expect these to yield the same KL as the plain Beta versions.
        d1_sp = beta_lib.BetaWithSoftplusAB(
            a=inverse_softplus(a1), b=inverse_softplus(b1))
        d2_sp = beta_lib.BetaWithSoftplusAB(
            a=inverse_softplus(a2), b=inverse_softplus(b2))

        # Reference KL written in terms of log-beta and digamma functions.
        kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                       (a1 - a2) * special.digamma(a1) +
                       (b1 - b2) * special.digamma(b1) +
                       (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

        # Every combination of plain / softplus-parameterized distributions
        # must agree with the same reference value.
        for dist1 in [d1, d1_sp]:
          for dist2 in [d2, d2_sp]:
            kl = kullback_leibler.kl(dist1, dist2)
            kl_val = sess.run(kl)
            self.assertEqual(kl.get_shape(), shape)
            self.assertAllClose(kl_val, kl_expected)

        # Make sure KL(d1||d1) is 0
        kl_same = sess.run(kullback_leibler.kl(d1, d1))
        self.assertAllClose(kl_same, np.zeros_like(kl_expected))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号