mixture_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testProbBatchMultivariate(self):
    """Check Mixture.prob on a batched multivariate mixture.

    The mixture is built with batch shape [2, 3], two components and event
    shape [4]. Two inputs are evaluated: one of shape (2, 3, 4) and one with
    an extra leading dimension, shape (4, 2, 3, 4). For each input the test
    asserts that:

    * the static shape and the evaluated shape of ``prob`` both equal the
      input shape with its last axis removed, and
    * the evaluated density equals the sum over components of
      (softmax of the categorical logits) * (component density).
    """
    with self.test_session() as sess:
      mixture = make_multivariate_mixture(
          batch_shape=[2, 3], num_components=2, event_shape=[4])

      # Both inputs are drawn up front, in this order, before any graph ops
      # are built for them.
      samples = [
          np.random.randn(2, 3, 4).astype(np.float32),
          np.random.randn(4, 2, 3, 4).astype(np.float32),
      ]
      for sample in samples:
        # The density is indexed by every input axis except the trailing one.
        expected_shape = sample.shape[:-1]

        mixture_density = mixture.prob(sample)
        self.assertEqual(expected_shape, mixture_density.get_shape())

        # Reference computation: mixing weights are the softmax of the
        # categorical logits; each component contributes its own density.
        weights = nn_ops.softmax(mixture.cat.logits)
        component_densities = [
            component.prob(sample) for component in mixture.components]

        fetched = sess.run([mixture_density, weights, component_densities])
        density_value = fetched[0]
        weights_value = fetched[1]
        component_values = fetched[2]
        self.assertEqual(expected_shape, density_value.shape)

        # NOTE(review): _swap_first_last_axes presumably moves the trailing
        # (component) axis of the weights to the front, so iterating over the
        # result pairs each component's weight with its density — confirm
        # against the helper's definition.
        weights_value = _swap_first_last_axes(weights_value)
        weighted_total = 0
        for weight, density in zip(weights_value, component_values):
          weighted_total = weighted_total + weight * density

        self.assertAllClose(weighted_total, density_value)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号