def HistoryWrapper(steps):
    """Build a gym wrapper class that keeps the last ``steps`` observations.

    Usage: ``env = HistoryWrapper(4)(env)``.

    :param steps: number of consecutive observations kept in the history;
        must be at least 1.
    :return: a ``gym.Wrapper`` subclass whose constructor takes the env to wrap.
    :raises ValueError: if ``steps`` is less than 1.
    """
    if steps < 1:
        # steps == 0 would make step() pop from an empty deque; negative
        # values would silently yield a 1-item history after reset().
        raise ValueError("steps must be >= 1, got %r" % (steps,))

    class _HistoryWrapper(gym.Wrapper):
        """
        Track history of observations for the given number of steps.

        Each observation returned by step/reset is a fresh numpy array of
        shape ``(steps,) + inner_observation_shape``, oldest entry first.
        Slots not yet filled by real observations are zero-filled.

        NOTE(review): uses the legacy ``_step``/``_reset`` hook names, which
        presumably targets an older gym whose ``step``/``reset`` dispatch to
        them -- verify against the installed gym version.
        """

        def __init__(self, env):
            super(_HistoryWrapper, self).__init__(env)
            self.steps = steps
            # Build (and validate) the stacked space before touching shapes.
            self.observation_space = self._make_observation_space(steps, env.observation_space)
            self.history = self._make_history()

        @staticmethod
        def _make_observation_space(steps, orig_obs):
            """Return a Box that repeats ``orig_obs`` bounds ``steps`` times along a new axis 0.

            :raises TypeError: if ``orig_obs`` is not a ``gym.spaces.Box``.
            """
            if not isinstance(orig_obs, gym.spaces.Box):
                raise TypeError(
                    "HistoryWrapper supports only gym.spaces.Box observation "
                    "spaces, got %r" % (orig_obs,))
            low = np.repeat(np.expand_dims(orig_obs.low, 0), steps, axis=0)
            high = np.repeat(np.expand_dims(orig_obs.high, 0), steps, axis=0)
            return gym.spaces.Box(low, high)

        def _make_history(self, last_item=None):
            """Create a full-length history deque, zero-padded on the left.

            :param last_item: optional observation placed as the newest entry.
            :return: ``collections.deque`` with exactly ``self.steps`` items
                and ``maxlen=self.steps``, so appends evict the oldest entry.
            """
            if last_item is None:
                pad = np.zeros(shape=self.env.observation_space.shape)
                pad_count = self.steps
            else:
                # Match the real observation's dtype/shape so stacking does
                # not upcast (e.g. uint8 frames padded with float64 zeros).
                pad = np.zeros_like(last_item)
                pad_count = self.steps - 1
            # Sharing one pad array is safe: entries are never mutated in place
            # and callers only ever receive a stacked copy.
            history = collections.deque([pad] * pad_count, maxlen=self.steps)
            if last_item is not None:
                history.append(last_item)
            return history

        def _stacked_history(self):
            """Return the history as a fresh ndarray, oldest observation first.

            A copy is returned so observations kept by callers are not altered
            when the internal deque is updated on later steps.
            """
            return np.stack(list(self.history))

        def _step(self, action):
            obs, reward, done, info = self.env.step(action)
            # maxlen on the deque drops the oldest observation automatically.
            self.history.append(obs)
            return self._stacked_history(), reward, done, info

        def _reset(self, **kwargs):
            obs = self.env.reset(**kwargs)
            self.history = self._make_history(last_item=obs)
            return self._stacked_history()

    return _HistoryWrapper
评论列表
文章目录