loss.py source code

Language: Python

Project: relaax · Author: deeplearninc
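The build_graph method below assembles the TensorFlow loss graph for a DQN learner. It picks out Q(s, a) for the action taken in each transition, forms the bootstrap target y = reward + gamma * (1 - terminal) * q_max, and returns the absolute-difference (L1) loss between the two. When config.double_dqn is set, q_max follows the double-DQN rule: the online network's next-state Q-values (ph_q_next) choose the action and the target network's values (ph_q_next_target) evaluate it; otherwise q_max is simply the maximum of the target network's outputs.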
import tensorflow as tf


def build_graph(self, q_network, config):
    # Placeholders for one batch of transitions from the replay buffer.
    self.ph_reward = tf.placeholder(tf.float32, [None])
    self.ph_action = tf.placeholder(tf.int32, [None])
    self.ph_terminal = tf.placeholder(tf.int32, [None])
    # Next-state Q-values, precomputed by the target and the online network.
    self.ph_q_next_target = tf.placeholder(tf.float32, [None, config.output.action_size])
    self.ph_q_next = tf.placeholder(tf.float32, [None, config.output.action_size])

    # Q(s, a) for the action actually taken in each transition.
    action_one_hot = tf.one_hot(self.ph_action, config.output.action_size)
    q_action = tf.reduce_sum(tf.multiply(q_network.node, action_one_hot), axis=1)

    if config.double_dqn:
        # Double DQN: the online network chooses argmax_a' Q(s', a'),
        # the target network evaluates the chosen action.
        q_max = tf.reduce_sum(self.ph_q_next_target * tf.one_hot(tf.argmax(self.ph_q_next, axis=1),
                                                                 config.output.action_size), axis=1)
    else:
        # Plain DQN: maximum of the target network's next-state Q-values.
        q_max = tf.reduce_max(self.ph_q_next_target, axis=1)

    # Bellman target; the (1 - terminal) factor drops the bootstrap term
    # at episode ends.
    y = self.ph_reward + tf.cast(1 - self.ph_terminal, tf.float32) * tf.scalar_mul(config.rewards_gamma, q_max)

    # L1 (absolute-difference) loss between taken-action Q-values and targets.
    return tf.losses.absolute_difference(q_action, y)
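
For context, here is a minimal usage sketch. It is not part of the relaax sources: the _Output, _Config, _QNetwork, and _Loss classes are hypothetical stand-ins for the project's real config and network objects, and the random arrays stand in for a replay-buffer batch; only build_graph itself comes from the snippet above.

import numpy as np
import tensorflow as tf


class _Output(object):
    action_size = 4


class _Config(object):
    output = _Output()
    double_dqn = True
    rewards_gamma = 0.99


class _QNetwork(object):
    # Stand-in for the project's q_network: exposes a `.node` tensor
    # of shape [batch, action_size].
    def __init__(self, state_size, action_size):
        self.ph_state = tf.placeholder(tf.float32, [None, state_size])
        self.node = tf.layers.dense(self.ph_state, action_size)


class _Loss(object):
    build_graph = build_graph  # attach the function above as a method


config = _Config()
q_network = _QNetwork(8, config.output.action_size)
loss_obj = _Loss()
loss = loss_obj.build_graph(q_network, config)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

batch = 32
feed = {
    q_network.ph_state: np.random.rand(batch, 8),
    loss_obj.ph_reward: np.random.rand(batch),
    loss_obj.ph_action: np.random.randint(0, 4, batch),
    loss_obj.ph_terminal: np.random.randint(0, 2, batch),
    # In real training these come from running the target / online
    # network on the next states; random values keep the sketch short.
    loss_obj.ph_q_next_target: np.random.rand(batch, 4),
    loss_obj.ph_q_next: np.random.rand(batch, 4),
}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, loss_value = sess.run([train_op, loss], feed)
    print('loss:', loss_value)

Feeding precomputed next-state Q-values through placeholders, as build_graph does, keeps the loss graph decoupled from how the target network is stored or synchronized.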