test_base.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_session_run(self):
        with self.test_session(use_gpu=True) as sess:
            samples = tf.constant([1, 2, 3])
            log_probs = Mock()
            probs = Mock()
            sample_func = Mock(return_value=samples)
            log_prob_func = Mock(return_value=log_probs)
            prob_func = Mock(return_value=probs)
            distribution = Mock(sample=sample_func,
                                log_prob=log_prob_func,
                                prob=prob_func,
                                dtype=tf.int32)

            # test session.run
            t = StochasticTensor('t', distribution, 1, samples)
            self.assertAllEqual(sess.run(t), np.asarray([1, 2, 3]))

            # test using as feed dict
            self.assertAllEqual(
                sess.run(tf.identity(t), feed_dict={
                    t: np.asarray([4, 5, 6])
                }),
                np.asarray([4, 5, 6])
            )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号