def select_action(self, st, runstep=None):
    """Choose an action epsilon-greedily from the Q-network's output for ``st``.

    A uniform draw from ``np.random.rand()`` is compared with
    ``self.params['eps']``.
    - If the draw is larger, the action with the highest Q-value is taken.
      Ties are broken uniformly at random.
    - Otherwise a uniformly random index into ``self.engine.legal_actions``
      is taken.

    Either way exactly one forward pass is run, inside ``self.G.as_default()``.

    Args:
        st: Observation. It is reshaped to a single-item batch of shape
            ``(1, 84, 84, 4)`` before being fed to ``self.qnet.x``.
        runstep: Callable used instead of ``self.sess.run`` when
            ``self.forward_only`` is true. It is called as
            ``runstep(sess, fetch, feed_dict=...)``.

    Returns:
        Tuple ``(act_idx, action, q_value)``:
        - ``act_idx`` indexes ``self.engine.legal_actions``.
        - ``action`` is the corresponding entry.
        - ``q_value`` is the greedy max Q-value or the Q-value of the
          randomly chosen action.
    """
    def forward_pass():
        # Run the network on a batch of one and drop the batch axis.
        if self.forward_only:
            out = runstep(self.sess, self.qnet.y,
                          feed_dict={self.qnet.x: np.reshape(st, (1, 84, 84, 4))})
        else:
            out = self.sess.run(self.qnet.y,
                                feed_dict={self.qnet.x: np.reshape(st, (1, 84, 84, 4))})
        return out[0]

    with self.G.as_default():
        greedy = np.random.rand() > self.params['eps']

        if not greedy:
            # Explore: pick the action index before the forward pass, so the
            # RNG is consumed in the same order as the greedy-free path expects.
            act_idx = np.random.randint(0, len(self.engine.legal_actions))
            q_values = forward_pass()
            return act_idx, self.engine.legal_actions[act_idx], q_values[act_idx]

        # Exploit: take the best action, breaking ties uniformly at random.
        # Only draw from the RNG when a tie actually exists.
        q_values = forward_pass()
        best = np.amax(q_values)
        ties = np.argwhere(q_values == best)
        if len(ties) > 1:
            act_idx = ties[np.random.randint(0, len(ties))][0]
        else:
            act_idx = ties[0][0]
        return act_idx, self.engine.legal_actions[act_idx], best
评论列表
文章目录