def testProbScalarUnivariate(self):
    """Check `prob` of a two-component mixture built with an empty batch shape.

    For each test point, the mixture's `prob` output must:
      * have the same shape as the point, both as the graph's static shape
        and as the evaluated NumPy value; and
      * match the hand-computed value
        sum_k softmax(cat.logits)[k] * components[k].prob(point).
    """
    with self.test_session() as sess:
        mixture = make_univariate_mixture(batch_shape=[], num_components=2)
        # NOTE(review): make_univariate_mixture may also draw from NumPy's
        # global RNG (not visible here); the random point below is created
        # after it so the draw order stays the same.
        test_points = [
            np.array([1.0, 2.0], dtype=np.float32),  # rank-1 input
            np.array(1.0, dtype=np.float32),  # rank-0 (scalar) input
            np.random.randn(3, 4).astype(np.float32),  # rank-2 random input
        ]
        for point in test_points:
            mixture_prob = mixture.prob(point)
            self.assertEqual(point.shape, mixture_prob.get_shape())

            # The logits are wrapped in a list so softmax sees an extra
            # leading axis, and `[0]` strips it again. NOTE(review):
            # presumably softmax required >= 2-D input when this was
            # written — confirm before simplifying.
            mixing_weights = nn_ops.softmax([mixture.cat.logits])[0]
            component_probs = [comp.prob(point) for comp in mixture.components]

            fetched = sess.run([mixture_prob, mixing_weights, component_probs])
            mixture_prob_value, weight_values, component_prob_values = fetched
            self.assertEqual(point.shape, mixture_prob_value.shape)

            # Reference value: the component densities weighted by the
            # categorical mixing probabilities, then summed.
            weighted_terms = [
                weight * prob
                for weight, prob in zip(weight_values, component_prob_values)
            ]
            expected_prob = sum(weighted_terms)
            self.assertAllClose(expected_prob, mixture_prob_value)
mixture_test.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录