common.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:rl 作者: Shmuma 项目源码 文件源码
def HistoryWrapper(steps):
    """Build a gym wrapper class that stacks the last ``steps`` observations.

    Usage: ``env = HistoryWrapper(4)(env)``.

    :param steps: number of most recent observations kept in the history;
        must be at least 1.
    :return: a ``gym.Wrapper`` subclass whose observations are numpy arrays
        stacking the last ``steps`` wrapped-env observations along a new
        leading axis.
    :raises ValueError: if ``steps`` is less than 1 (a zero or negative
        history length would silently produce a history of the wrong size).
    """
    if steps < 1:
        raise ValueError("steps must be >= 1, got %r" % (steps,))

    class _HistoryWrapper(gym.Wrapper):
        """
        Track history of observations for given amount of steps.
        Initial steps are zero-filled.

        Each observation returned by step/reset is a fresh numpy array built
        from the history, so callers may keep it without it being changed by
        later steps.
        """
        def __init__(self, env):
            super(_HistoryWrapper, self).__init__(env)
            self.steps = steps
            self.history = self._make_history()
            self.observation_space = self._make_observation_space(steps, env.observation_space)

        @staticmethod
        def _make_observation_space(steps, orig_obs):
            """Return a Box whose bounds repeat ``orig_obs``'s bounds ``steps`` times
            along a new leading axis."""
            assert isinstance(orig_obs, gym.spaces.Box)
            low = np.repeat(np.expand_dims(orig_obs.low, 0), steps, axis=0)
            high = np.repeat(np.expand_dims(orig_obs.high, 0), steps, axis=0)
            return gym.spaces.Box(low, high)

        def _make_history(self, last_item=None):
            """Return a deque of exactly ``self.steps`` entries: zero padding,
            optionally followed by ``last_item`` as the newest entry."""
            size = self.steps if last_item is None else self.steps - 1
            zero = np.zeros(shape=self.env.observation_space.shape)
            # maxlen makes append() drop the oldest entry automatically, so the
            # history length stays fixed at self.steps.
            res = collections.deque([zero] * size, maxlen=self.steps)
            if last_item is not None:
                res.append(last_item)
            return res

        def _stacked_history(self):
            # Snapshot the history into a new array. Returning the deque itself
            # would hand out a live reference that the next step mutates, and
            # would not match the Box declared in observation_space.
            return np.array(self.history)

        def _step(self, action):
            obs, reward, done, info = self.env.step(action)
            self.history.append(obs)  # oldest entry evicted via maxlen
            return self._stacked_history(), reward, done, info

        def _reset(self):
            self.history = self._make_history(last_item=self.env.reset())
            return self._stacked_history()

    return _HistoryWrapper
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号