mixture_test.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testEntropyLowerBoundMultivariate(self):
  """Checks Mixture.entropy_lower_bound() for multivariate components.

  For each batch shape in ((), (2,), (2, 3)), a two-component mixture with
  event shape (4,) is built with make_multivariate_mixture. Its entropy
  lower bound is compared against a reference value computed in NumPy.
  The reference is sum_i pi_i * H_i, evaluated batchwise. Here pi_i is the
  softmax of the categorical logits and H_i is the entropy of component i.
  The test also asserts that both the static shape and the evaluated shape
  of the bound equal the batch shape.
  """
  with self.test_session() as sess:
    for batch_shape in ((), (2,), (2, 3)):
      mixture = make_multivariate_mixture(
          batch_shape=batch_shape, num_components=2, event_shape=(4,))
      bound = mixture.entropy_lower_bound()
      # The graph-time (static) shape must already match the batch shape.
      self.assertEqual(batch_shape, bound.get_shape())

      mixing_probs = nn_ops.softmax(mixture.cat.logits)
      component_entropies = [
          component.entropy() for component in mixture.components]

      # Evaluate the bound and the pieces of its reference in one run.
      fetched = sess.run([bound, mixing_probs, component_entropies])
      bound_value, mixing_probs_value, component_entropies_value = fetched
      self.assertEqual(batch_shape, bound_value.shape)

      # NOTE(review): _swap_first_last_axes is defined elsewhere in this
      # file. Presumably it moves the trailing (component) axis of the
      # softmax output to the front, so zipping below pairs each component's
      # probabilities with that component's entropy. Confirm.
      per_component_probs = _swap_first_last_axes(mixing_probs_value)

      # Reference bound: sum over components of pi_i * entropy_i, batchwise.
      expected_bound = sum(
          prob * entropy for prob, entropy in zip(
              per_component_probs, component_entropies_value))

      self.assertAllClose(expected_bound, bound_value)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号