def map(self, state):
    """Sample an action for ``state`` by evaluating the policy graph.

    The state is flattened and checked against ``self.state_space``. The
    graph tensors ``self.a_pred`` and ``self.var`` are then evaluated with
    the state fed to ``self.X`` as a batch of one. An action is drawn with
    ``normal(mean, var)`` and reshaped to ``self.action_space.shape``.

    The session used is ``self.sess`` when it is set, otherwise TensorFlow's
    default session. Either set ``self.sess`` or call this inside a
    ``with session.as_default():`` block.

    Args:
        state: Observation exposing ``.flatten()`` (e.g. a numpy array).

    Returns:
        np.ndarray: The sampled action, shaped like ``self.action_space.shape``.

    Raises:
        ValueError: If the flattened state is not in ``self.state_space``.
        RuntimeError: If ``self.sess`` is None and no default session is set.
    """
    state = state.flatten()
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let invalid states through silently.
    if not self.state_space.contains(state):
        raise ValueError(
            "state {!r} is not contained in the state space".format(state))

    sess = self.sess if self.sess is not None else tf.get_default_session()
    if sess is None:
        # get_default_session() returns None outside an ``as_default()``
        # context; fail with a clear message instead of an AttributeError.
        raise RuntimeError(
            "no TensorFlow session: set `sess` on this object or call map() "
            "inside a `with session.as_default():` block")

    # The graph takes a batch, so feed a list holding the single state.
    mean, var = sess.run([self.a_pred, self.var], {self.X: [state]})
    # NOTE(review): ``normal`` is defined outside this view. If it is
    # numpy.random.normal, its second argument is the standard deviation
    # (scale), not the variance -- confirm what ``self.var`` actually holds.
    action = np.array(normal(mean, var))
    return action.reshape(self.action_space.shape)
评论列表
文章目录