def testEntropyLowerBoundMultivariate(self):
    """Check `entropy_lower_bound()` on multivariate mixtures.

    For each tested batch shape, the bound's static shape and its evaluated
    shape must both equal the batch shape, and its value must be close to
    the probability-weighted sum of the component entropies:
    sum_i pi_i * entropy_i over the mixture components.
    """
    batch_shapes = ((), (2,), (2, 3))
    with self.test_session() as sess:
        for batch_shape in batch_shapes:
            mixture = make_multivariate_mixture(
                batch_shape=batch_shape, num_components=2, event_shape=(4,))
            bound = mixture.entropy_lower_bound()
            self.assertEqual(batch_shape, bound.get_shape())

            # Graph ops for the mixing probabilities (softmax of the
            # categorical logits) and for each component's entropy.
            mixing_probs = nn_ops.softmax(mixture.cat.logits)
            component_entropies = [
                component.entropy() for component in mixture.components]

            # Evaluate everything in a single run call.
            fetched = sess.run([bound, mixing_probs, component_entropies])
            bound_value, mixing_probs_value, entropy_values = fetched
            self.assertEqual(batch_shape, bound_value.shape)

            # NOTE(review): presumably this moves the component axis to the
            # front so that zip() below pairs each component's probabilities
            # with that component's entropy -- confirm against the helper.
            mixing_probs_value = _swap_first_last_axes(mixing_probs_value)

            # entropy_lower_bound = sum_i pi_i entropy_i
            # for i in num_components, batchwise.
            weighted_entropies = [
                prob * entropy
                for prob, entropy in zip(mixing_probs_value, entropy_values)
            ]
            expected_bound = sum(weighted_entropies)
            self.assertAllClose(expected_bound, bound_value)
mixture_test.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录