mixture_test.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testProbScalarUnivariate(self):
    """Mixture.prob must equal the weighted sum of component densities.

    Builds a two-component univariate mixture with an empty batch shape and
    evaluates it at a 1-D vector, a 0-D scalar and a random 3x4 matrix. For
    each input it checks two things:

    * the static and evaluated shapes of the result equal the input shape;
    * the result matches sum_k softmax(cat.logits)[k] * component_k.prob(x),
      computed independently here.
    """
    with self.test_session() as sess:
      mixture = make_univariate_mixture(batch_shape=[], num_components=2)
      # Inputs of rank 1, rank 0 and rank 2. The random matrix is drawn once,
      # up front, exactly as when the inputs are built in a literal.
      vector_input = np.array([1.0, 2.0], dtype=np.float32)
      scalar_input = np.array(1.0, dtype=np.float32)
      matrix_input = np.random.randn(3, 4).astype(np.float32)
      inputs = [vector_input, scalar_input, matrix_input]

      for sample in inputs:
        mixture_prob_op = mixture.prob(sample)
        # Static shape inference must already agree with the input shape.
        self.assertEqual(sample.shape, mixture_prob_op.get_shape())

        # Mixing weights: softmax over the categorical logits. The logits
        # are wrapped in a list (adding a leading axis) and that axis is
        # indexed away again with [0].
        weights_op = nn_ops.softmax([mixture.cat.logits])[0]
        component_prob_ops = [
            component.prob(sample) for component in mixture.components
        ]

        fetched = sess.run([mixture_prob_op, weights_op, component_prob_ops])
        mixture_prob_val, weights_val, component_probs_val = fetched
        self.assertEqual(sample.shape, mixture_prob_val.shape)

        # Reference density: accumulate weight_k * p_k(x) starting from 0.
        expected_prob = 0
        for weight_val, component_prob_val in zip(weights_val,
                                                  component_probs_val):
          expected_prob = expected_prob + weight_val * component_prob_val

        self.assertAllClose(expected_prob, mixture_prob_val)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号