def train_priority(self, state, reward, action, state_next, done, batch_ISweight):
    # Q-values of the current states (online net) and of the next states (target net)
    q, q_target = self.sess.run([self.q_value, self.q_target],
                                feed_dict={self.inputs_q: state, self.inputs_target: state_next})
    if self.double:
        # Double DQN: choose the greedy next action with the online network,
        # then evaluate that action with the target network
        q_next = self.sess.run(self.q_value, feed_dict={self.inputs_q: state_next})
        action_best = np.argmax(q_next, axis=1)
        q_target_best = self.sess.run(self.q_target_action,
                                      feed_dict={self.action_best: action_best,
                                                 self.q_target: q_target})
    else:
        # Vanilla DQN: take the max over the target network's Q-values
        q_target_best = np.max(q_target, axis=1)
    # Drop the bootstrap term for terminal transitions
    q_target_best_mask = (1.0 - done) * q_target_best
    target = reward + self.gamma * q_target_best_mask
    # Duplicate the importance-sampling weights so they broadcast over the loss
    batch_ISweight = np.stack([batch_ISweight, batch_ISweight], axis=-1)
    loss, td_error, _ = self.sess.run([self.loss, self.td_error, self.train_op],
                                      feed_dict={self.inputs_q: state, self.target: target,
                                                 self.action: action, self.ISweight: batch_ISweight})
    # self.loss_his.append(loss)
    return td_error
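# --- Usage sketch (not part of the class above) ------------------------------
# The TD errors returned by train_priority() are what prioritized replay uses to
# refresh the priorities of the sampled transitions. A minimal sketch, assuming a
# SumTree-style memory exposing sample() and batch_update() (hypothetical names,
# not defined in this post):
#
#   tree_idx, (s, a, r, s_, done), is_w = memory.sample(batch_size)
#   td_error = agent.train_priority(s, r, a, s_, done, is_w)
#   memory.batch_update(tree_idx, np.abs(td_error))   # new priority grows with |TD error|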
# ===============================================================
# A3C Agent
# ===============================================================