def test_session_run(self):
    """Check that a StochasticTensor can be evaluated and fed via ``Session.run``.

    A ``Mock`` object stands in for the distribution. The constant
    ``[1, 2, 3]`` tensor is handed to the ``StochasticTensor`` constructor,
    and the test expects:

    * running the StochasticTensor directly yields ``[1, 2, 3]``;
    * using the StochasticTensor as a ``feed_dict`` key overrides its value
      for ops built on top of it (checked through ``tf.identity``).
    """
    with self.test_session(use_gpu=True) as sess:
        sample_tensor = tf.constant([1, 2, 3])

        # Stub distribution: sample/log_prob/prob are Mocks returning fixed
        # objects; the log_prob/prob results are opaque placeholders.
        fake_distribution = Mock(
            sample=Mock(return_value=sample_tensor),
            log_prob=Mock(return_value=Mock()),
            prob=Mock(return_value=Mock()),
            dtype=tf.int32,
        )

        # test session.run
        stochastic = StochasticTensor('t', fake_distribution, 1, sample_tensor)
        fetched = sess.run(stochastic)
        self.assertAllEqual(fetched, np.asarray([1, 2, 3]))

        # test using as feed dict: the fed value should flow through
        # tf.identity instead of the original samples.
        fed_value = np.asarray([4, 5, 6])
        identity_out = sess.run(
            tf.identity(stochastic),
            feed_dict={stochastic: fed_value},
        )
        self.assertAllEqual(identity_out, np.asarray([4, 5, 6]))
评论列表
文章目录