def testProbBatchMultivariate(self):
  """Mixture prob should equal the mixing-weighted sum of component probs.

  The mixture is built with batch_shape=[2, 3], num_components=2 and
  event_shape=[4]. It is evaluated on a (2, 3, 4) input and on a
  (4, 2, 3, 4) input. For each input, the test checks the static and
  evaluated shape of `prob`, then compares the result numerically against
  the weighted sum of the component densities.
  """
  with self.test_session() as sess:
    mixture = make_multivariate_mixture(
        batch_shape=[2, 3], num_components=2, event_shape=[4])
    samples = [
        np.random.randn(*shape).astype(np.float32)
        for shape in ((2, 3, 4), (4, 2, 3, 4))
    ]
    for sample in samples:
      # prob() drops the trailing event dimension of the input.
      expected_shape = sample.shape[:-1]
      prob_tensor = mixture.prob(sample)
      self.assertEqual(expected_shape, prob_tensor.get_shape())
      weights_tensor = nn_ops.softmax(mixture.cat.logits)
      component_tensors = [comp.prob(sample) for comp in mixture.components]
      prob_value, weights_value, component_values = sess.run(
          [prob_tensor, weights_tensor, component_tensors])
      self.assertEqual(expected_shape, prob_value.shape)
      # Reorder the softmax output so that iterating over its first axis
      # yields one weight array per component, pairing with the list of
      # component probs. This presumably brings the (last) component axis to
      # the front; confirm against _swap_first_last_axes.
      weights_value = _swap_first_last_axes(weights_value)
      total_prob = 0
      for weight, component_prob in zip(weights_value, component_values):
        total_prob = total_prob + weight * component_prob
      self.assertAllClose(total_prob, prob_value)
mixture_test.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录