def get_action(self, observation):
    """Sample one action for *observation* and advance the recurrent state.

    Flattens the observation (optionally concatenated with the flattened
    previous action when ``self.state_include_action`` is set), runs a single
    recurrent step through ``self.f_step_prob`` to get action probabilities
    and the next hidden vector, samples a discrete action from those
    probabilities, and records it as the new ``prev_action``.

    Returns:
        tuple: ``(action, agent_info)`` where ``agent_info`` always carries
        ``prob`` (the categorical distribution used for sampling) and, when
        ``state_include_action`` is enabled, ``prev_action`` (the flattened
        action fed into this step).
    """
    flat_obs = self.observation_space.flatten(observation)
    if self.state_include_action:
        # First step of an episode has no previous action: feed zeros.
        if self.prev_action is None:
            flat_prev_action = np.zeros((self.action_space.flat_dim,))
        else:
            flat_prev_action = self.action_space.flatten(self.prev_action)
        all_input = np.concatenate([flat_obs, flat_prev_action])
    else:
        all_input = flat_obs
        # Placeholder only — never consumed in this branch.
        flat_prev_action = np.nan
    # Single recurrent step on a batch of one; strip the batch dimension.
    step_out = self.f_step_prob([all_input], [self.prev_hidden])
    probs = step_out[0][0]
    hidden_vec = step_out[1][0]
    action = special.weighted_sample(probs, range(self.action_space.n))
    self.prev_action = action
    self.prev_hidden = hidden_vec
    agent_info = dict(prob=probs)
    if self.state_include_action:
        agent_info["prev_action"] = flat_prev_action
    return action, agent_info