entropy_test.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def test_fitting_two_dimensional_normal_n_equals_1000(self):
    """Fit a 2-D normal `q` to a fixed 2-D normal target via Renyi divergence.

    For each alpha in (0.25, 0.75) a train op is built that minimizes
    `reduce_mean(-negative_renyi_divergence)`, where the objective comes from
    `entropy.renyi_ratio` with `n` Monte Carlo samples and a fixed seed. All
    global variables are re-initialized before each alpha's 1000 training
    steps, so every alpha starts from the same zero-initialized `mu` and
    `mat`. The objective is recorded after steps 1, 5 and 100 and checked
    with `_assert_monotonic_increasing`; afterwards `q.mu` and `q.sigma` must
    be close to the target's.
    """
    # Minimizing the Renyi divergence should allow us to make one normal
    # match another one exactly.
    n = 1000
    mu_true = np.array([1.0, -1.0], dtype=np.float64)
    chol_true = np.array([[2.0, 0.0], [0.5, 1.0]], dtype=np.float64)
    # Training steps (0-based) after which the objective is recorded. Steps
    # are visited in increasing order, so values are appended in that order.
    record_steps = frozenset([1, 5, 100])
    with self.test_session() as sess:
      target = distributions.MultivariateNormalCholesky(mu_true, chol_true)

      # Set up the q distribution by defining its mean and Cholesky factor
      # as Variables, both starting at zero.
      mu = variables.Variable(
          np.zeros(mu_true.shape), dtype=mu_true.dtype, name='mu')
      mat = variables.Variable(
          np.zeros(chol_true.shape), dtype=chol_true.dtype, name='mat')
      # NOTE(review): softplus is presumably applied to the diagonal of `mat`
      # only, keeping it positive so `chol` is a valid Cholesky factor --
      # confirm against `matrix_diag_transform`.
      chol = distributions.matrix_diag_transform(mat, transform=nn_ops.softplus)
      q = distributions.MultivariateNormalCholesky(mu, chol)
      for alpha in [0.25, 0.75]:

        negative_renyi_divergence = entropy.renyi_ratio(
            log_p=target.log_prob, q=q, n=n, alpha=alpha, seed=0)
        # Minimizing the negated objective maximizes the negative Renyi
        # divergence, i.e. minimizes the divergence itself.
        train_op = get_train_op(
            math_ops.reduce_mean(-negative_renyi_divergence),
            optimizer='SGD',
            learning_rate=0.5,
            decay=0.1)

        # Re-initialize every global variable so each alpha starts from the
        # same zero-initialized q rather than from the previous alpha's fit.
        variables.global_variables_initializer().run()
        renyis = []
        for step in range(1000):
          sess.run(train_op)
          if step in record_steps:
            renyis.append(negative_renyi_divergence.eval())

        # The optimization maximizes the negative Renyi divergence, so the
        # recorded values should increase as training proceeds.
        _assert_monotonic_increasing(renyis, atol=0)

        # Relative tolerance (rtol) chosen 2 times as large as the minimum
        # needed to pass.
        self.assertAllClose(target.mu.eval(), q.mu.eval(), rtol=0.06)
        self.assertAllClose(target.sigma.eval(), q.sigma.eval(), rtol=0.02)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号