def testMeanUnivariate(self):
    """Check the mixture's `mean()` against a hand-computed weighted sum.

    For each batch shape in (), (2,) and (2, 3), a two-component mixture is
    built with `make_univariate_mixture`. The test asserts that the mean has
    the batch shape both statically and after evaluation, and that its value
    matches the sum over components of
    softmax(categorical logits) * component mean.
    """
    with self.test_session() as sess:
        for batch_shape in ((), (2,), (2, 3)):
            mixture = make_univariate_mixture(
                batch_shape=batch_shape, num_components=2)
            mixture_mean = mixture.mean()
            # The static (graph-time) shape must equal the batch shape.
            self.assertEqual(batch_shape, mixture_mean.get_shape())

            weights = nn_ops.softmax(mixture.cat.logits)
            component_means = [comp.mean() for comp in mixture.components]

            # Fetch the mean, the weights and every component mean in one
            # sess.run call.
            fetched = sess.run([mixture_mean, weights, component_means])
            mixture_mean_np, weights_np, component_means_np = fetched
            self.assertEqual(batch_shape, mixture_mean_np.shape)

            # NOTE(review): presumably this moves the component axis to the
            # front so that iteration yields one weight array per component
            # -- confirm against the helper's definition.
            weights_np = _swap_first_last_axes(weights_np)
            expected_mean = sum(
                weight * comp_mean
                for weight, comp_mean in zip(weights_np, component_means_np))
            self.assertAllClose(expected_mean, mixture_mean_np)
mixture_test.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录