def eps_greedy(self, state_batch, exploration_rate):
    """Pick one action per state using an epsilon-greedy rule.

    A single uniform random draw decides the mode for the *whole* batch:
    if it falls below ``exploration_rate`` every state gets a uniformly
    random action index, otherwise every state gets the index with the
    largest value returned by ``self.compute_q_variable``.  In both cases
    the indices are then translated with ``self.get_action_for_index``.

    Args:
        state_batch: Array-like exposing ``ndim``/``reshape``/``shape``
            (presumably a NumPy array -- confirm against callers).  A 1-D
            input is treated as a single state, a 3-D input is flattened
            to rows of ``34 * config.rl_history_length`` values, and any
            other rank is used as-is.
        exploration_rate: Threshold compared against a draw from
            ``np.random.uniform()``; values >= 1 always explore, values
            <= 0 always exploit.

    Returns:
        Tuple ``(action_batch, q)``.  ``action_batch`` is a 1-D integer
        array with one entry per state, holding the results of
        ``self.get_action_for_index``.  ``q`` is the ``.data`` of the
        computed Q-values, or ``None`` when the random branch was taken.
    """
    # Bring the input to a 2-D (batch, features) layout.
    if state_batch.ndim == 1:
        state_batch = state_batch.reshape(1, -1)
    elif state_batch.ndim == 3:
        # NOTE(review): 34 looks like the per-step feature count -- confirm.
        state_batch = state_batch.reshape(-1, 34 * config.rl_history_length)

    if np.random.uniform() < exploration_rate:
        # Explore: one random action index per state; no Q-values exist.
        action_batch = np.random.randint(
            0, len(config.actions), (state_batch.shape[0],))
        q = None
    else:
        # Exploit: evaluate the Q-function and take the best index per row.
        state_batch = Variable(state_batch)
        if config.use_gpu:
            # The result is not reassigned, so this relies on to_gpu()
            # moving the data in place.
            state_batch.to_gpu()
        q = self.compute_q_variable(state_batch, test=True)
        if config.use_gpu:
            # Same in-place assumption; np.argmax below needs host memory.
            q.to_cpu()
        q = q.data
        action_batch = np.argmax(q, axis=1)

    # Both branches produce action *indices*; map them to actions.  The
    # write is done in place so the array keeps its integer dtype.
    for i, action_index in enumerate(action_batch):
        action_batch[i] = self.get_action_for_index(action_index)
    return action_batch, q
# (scraped page residue, not code) Comment list
# (scraped page residue, not code) Table of contents