mixture_test.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi

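# Method of a tf.test.TestCase subclass in mixture_test.py; make_univariate_mixture
# and _swap_first_last_axes are helpers defined elsewhere in that file, and
# nn_ops comes from tensorflow.python.ops.
def testMeanUnivariate(self):
    with self.test_session() as sess:
      # Cover scalar, vector, and matrix batch shapes.
      for batch_shape in ((), (2,), (2, 3)):
        dist = make_univariate_mixture(
            batch_shape=batch_shape, num_components=2)
        mean = dist.mean()
        # A univariate mixture has one scalar mean per batch member.
        self.assertEqual(batch_shape, mean.get_shape())

        # Mixing weights: softmax of the categorical logits over the last axis.
        cat_probs = nn_ops.softmax(dist.cat.logits)
        dist_means = [d.mean() for d in dist.components]

        mean_value, cat_probs_value, dist_means_value = sess.run(
            [mean, cat_probs, dist_means])
        self.assertEqual(batch_shape, mean_value.shape)

        # Move the component axis to the front so zip() pairs each weight
        # slice with the matching component mean.
        cat_probs_value = _swap_first_last_axes(cat_probs_value)
        # Reference value: E[X] = sum_k p_k * mu_k.
        true_mean = sum(
            [c_p * m for (c_p, m) in zip(cat_probs_value, dist_means_value)])

        self.assertAllClose(true_mean, mean_value)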
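The test compares the mixture mean with the closed form E[X] = sum_k p_k * mu_k, where p = softmax(logits) are the categorical mixing weights. Below is a minimal NumPy sketch of that reference computation, assuming plain arrays instead of TensorFlow tensors. `softmax`, `swap_first_last_axes`, and `mixture_mean` are hypothetical helpers written for illustration; `swap_first_last_axes` here just moves the last (component) axis to the front, which is what the test relies on so that `zip` iterates over components. It is not the file's own helper.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def swap_first_last_axes(x):
    """Hypothetical stand-in: move the last axis to position 0.

    Shape batch_shape + (K,) becomes (K,) + batch_shape.
    """
    return np.moveaxis(x, -1, 0)


def mixture_mean(cat_logits, component_means):
    """Mean of a batch of univariate mixtures: sum_k p_k * mu_k.

    cat_logits: array of shape batch_shape + (K,).
    component_means: list of K arrays (or scalars), each of shape batch_shape.
    """
    # probs[k] holds the weight of component k for every batch member.
    probs = swap_first_last_axes(softmax(cat_logits))
    return sum(p * m for p, m in zip(probs, component_means))


# Batch of two mixtures with K = 2 components each.
logits = np.array([[0.0, 0.0], [np.log(3.0), 0.0]])
means = [np.array([1.0, 10.0]), np.array([3.0, 2.0])]
result = mixture_mean(logits, means)  # approximately [2.0, 8.0]
```

The second batch member has weights (0.75, 0.25), so its mean is 0.75 * 10 + 0.25 * 2 = 8. Moving the component axis to the front, rather than stacking the means along the last axis, mirrors the structure of the test's list comprehension.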