def build_graph(self, q_network, config):
    # Placeholders for a sampled replay batch.
    self.ph_reward = tf.placeholder(tf.float32, [None])
    self.ph_action = tf.placeholder(tf.int32, [None])
    self.ph_terminal = tf.placeholder(tf.int32, [None])
    self.ph_q_next_target = tf.placeholder(tf.float32, [None, config.output.action_size])
    self.ph_q_next = tf.placeholder(tf.float32, [None, config.output.action_size])

    # Q-value of the action actually taken, selected with a one-hot mask.
    action_one_hot = tf.one_hot(self.ph_action, config.output.action_size)
    q_action = tf.reduce_sum(tf.multiply(q_network.node, action_one_hot), axis=1)

    if config.double_dqn:
        # Double DQN: the online network picks the next action,
        # the target network evaluates it.
        q_max = tf.reduce_sum(
            self.ph_q_next_target *
            tf.one_hot(tf.argmax(self.ph_q_next, axis=1), config.output.action_size),
            axis=1)
    else:
        # Plain DQN: take the target network's own maximum over next actions.
        q_max = tf.reduce_max(self.ph_q_next_target, axis=1)

    # Bellman target: y = r + gamma * q_max, zeroed on terminal transitions.
    y = self.ph_reward + tf.cast(1 - self.ph_terminal, tf.float32) * tf.scalar_mul(config.rewards_gamma, q_max)

    return tf.losses.absolute_difference(q_action, y)
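To make the target computation concrete, here is the same `y` written out in plain NumPy. This is only an illustrative sketch mirroring the graph above (the function name and array arguments are hypothetical, not part of the original code): standard DQN takes the target network's own maximum over next-state values, while the double-DQN branch lets the online network choose the action and the target network evaluate it.

import numpy as np

def dqn_target(reward, terminal, q_next_target, q_next, gamma, double_dqn):
    """NumPy sketch of the Bellman target built in build_graph()."""
    if double_dqn:
        # Online network selects the action, target network scores it.
        best_action = np.argmax(q_next, axis=1)
        q_max = q_next_target[np.arange(len(best_action)), best_action]
    else:
        # Plain DQN: maximum of the target network's next-state values.
        q_max = np.max(q_next_target, axis=1)
    # Terminal transitions contribute only the immediate reward.
    return reward + (1.0 - terminal) * gamma * q_max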