agent.py file source

python

Project: DRLModule    Author: halleanwoo
def train_priority(self, state, reward, action, state_next, done, batch_ISweight):
    # Q values of the online network for the current states and of the
    # target network for the next states.
    q, q_target = self.sess.run([self.q_value, self.q_target],
                                feed_dict={self.inputs_q: state, self.inputs_target: state_next})
    if self.double:
        # Double DQN: pick the greedy action with the online network,
        # then evaluate that action with the target network.
        q_next = self.sess.run(self.q_value, feed_dict={self.inputs_q: state_next})
        action_best = np.argmax(q_next, axis=1)
        q_target_best = self.sess.run(self.q_target_action,
                                      feed_dict={self.action_best: action_best,
                                                 self.q_target: q_target})
    else:
        # Vanilla DQN: take the maximum target-network Q value.
        q_target_best = np.max(q_target, axis=1)

    # Zero out the bootstrap term for terminal transitions.
    q_target_best_mask = (1.0 - done) * q_target_best
    target = reward + self.gamma * q_target_best_mask
    # Duplicate the importance-sampling weights along the last axis so they
    # broadcast against the loss terms.
    batch_ISweight = np.stack([batch_ISweight, batch_ISweight], axis=-1)
    loss, td_error, _ = self.sess.run([self.loss, self.td_error, self.train_op],
                                      feed_dict={self.inputs_q: state, self.target: target,
                                                 self.action: action, self.ISweight: batch_ISweight})
    # self.loss_his.append(loss)
    return td_error
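
The returned TD errors are what a prioritized replay buffer would feed back as new priorities. Below is a minimal usage sketch; the buffer's sample() and update_priorities() methods and their signatures are assumptions for illustration, not part of this file.

import numpy as np

def train_step(agent, buffer, batch_size=32, per_epsilon=1e-6):
    # Sample a prioritized batch together with importance-sampling weights
    # (hypothetical buffer API).
    idx, (state, action, reward, state_next, done), is_weight = buffer.sample(batch_size)
    # One gradient step; train_priority returns the per-sample TD errors.
    td_error = agent.train_priority(state, reward, action, state_next, done, is_weight)
    # Use |TD error| as the new priority; epsilon keeps priorities non-zero.
    buffer.update_priorities(idx, np.abs(td_error).reshape(-1) + per_epsilon)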


# ===============================================================
#                             A3C Agent
# ===============================================================