ddpg.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:value_gradient 作者: rarilurelo 项目源码 文件源码
def build(self):
        """Assemble the training graph: placeholders, losses and update ops.

        Reads the online and target sub-models from ``self.net`` and stores
        the following on ``self``:

        * placeholders ``states``, ``actions``, ``rewards``, ``next_states``
          and ``ys`` (``rewards`` and ``next_states`` are created here but are
          not wired into any op built by this method);
        * ``target_q`` -- target critic applied to the target actor's action,
          both evaluated on ``states``;
        * ``q`` / ``q_loss`` -- online critic value for ``actions`` and its
          mean squared difference from ``ys``;
        * ``mu`` / ``pi_loss`` -- online actor output and the negated mean
          critic value of that output;
        * ``q_updater`` / ``pi_updater`` -- optimizer steps restricted to
          ``net.var_q`` and ``net.var_pi`` respectively;
        * ``soft_updater`` / ``sync`` -- per-variable ops that blend (with
          ``self.tau``) or copy the online variables into the target ones.

        Finally runs the global variable initializer in ``self.sess`` and
        sets ``self.built`` to ``True``.
        """
        net = self.net
        shared, actor, critic = net.model, net.pi_model, net.q_model
        shared_tgt = net.target_model
        actor_tgt = net.target_pi_model
        critic_tgt = net.target_q_model

        # --- graph inputs -------------------------------------------------
        self.states = tf.placeholder(
            tf.float32, shape=(None, self.in_dim), name='states')
        self.actions = tf.placeholder(
            tf.float32, shape=[None, self.action_dim], name='actions')
        self.rewards = tf.placeholder(
            tf.float32, shape=[None], name='rewards')
        self.next_states = tf.placeholder(
            tf.float32, shape=[None, self.in_dim], name='next_states')
        # NOTE: a boolean `terminals` placeholder (0/1 mask) and an in-graph
        # target of the form
        #   where(terminals, rewards, rewards + gamma * Q'(s', pi'(s')))
        # were previously sketched here and are disabled. The regression
        # target is supplied through `ys` instead -- presumably computed by
        # the caller from `target_q`; confirm against the training loop.
        self.ys = tf.placeholder(tf.float32, shape=[None])

        # --- critic -------------------------------------------------------
        tgt_joint = Concatenate()(
            [shared_tgt(self.states), actor_tgt(self.states)])
        self.target_q = K.sum(critic_tgt(tgt_joint), axis=-1)

        replay_joint = Concatenate()([shared(self.states), self.actions])
        self.q = K.sum(critic(replay_joint), axis=-1)
        self.q_loss = K.mean(K.square(self.ys-self.q))

        # --- actor --------------------------------------------------------
        self.mu = actor(self.states)
        policy_joint = Concatenate()([shared(self.states), self.mu])
        # Minimising the negated critic value pushes `mu` toward actions the
        # critic scores higher.
        self.pi_loss = - K.mean(critic(policy_joint))

        # --- optimizer steps ----------------------------------------------
        self.q_updater = self.q_optimizer.minimize(
            self.q_loss, var_list=net.var_q)
        # NOTE(review): `pi_opimizer` (sic) is kept as-is; it has to match
        # the attribute name assigned elsewhere in the class.
        self.pi_updater = self.pi_opimizer.minimize(
            self.pi_loss, var_list=net.var_pi)

        # --- target-network maintenance ------------------------------------
        # Soft update: target <- (1 - tau) * target + tau * online.
        blend_ops = []
        for online_var, target_var in zip(net.var_all, net.var_target_all):
            blended = target_var*(1-self.tau)+online_var*self.tau
            blend_ops.append(K.update(target_var, blended))
        self.soft_updater = blend_ops

        # Hard sync: target <- online.
        copy_ops = []
        for online_var, target_var in zip(net.var_all, net.var_target_all):
            copy_ops.append(K.update(target_var, online_var))
        self.sync = copy_ops

        self.sess.run(tf.global_variables_initializer())
        self.built = True
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号