def __init__(self, env, obs_stack):
    """Wrap ``env`` so that observations are stacked along their last axis.

    Args:
        env: The environment to wrap. Its ``observation_space`` must expose a
            ``shape`` and a ``reshape`` method (presumably a project-specific
            space type -- TODO confirm) and must not be a ``Tuple`` space.
        obs_stack: Number of consecutive observations to stack; must be > 1.

    Raises:
        AssertionError: If ``obs_stack`` is not greater than 1, or if the
            wrapped observation space is a ``Tuple`` space.
    """
    import copy  # local import: the file's top-level imports are not visible

    super(ObservationStackWrap, self).__init__(env=env)
    # NOTE(review): these asserts are stripped under ``python -O``; they are
    # kept as asserts so the exception type seen by existing callers does not
    # change.
    assert obs_stack > 1, "Observation stack length must be higher than 1."
    assert not isinstance(self.observation_space, Tuple),\
        "Observation stack is not compatible with Tuple spaces."
    # The assert above already guarantees obs_stack > 1, so the former
    # ``or 1`` fallback was dead code; store the validated value directly and
    # use it consistently below.
    self._obs_stack_len = obs_stack

    # Work on a private copy: reshaping the wrapped env's own space object
    # would also change what ``self.env.observation_space`` reports, even
    # though the inner env still emits unstacked observations.
    space = copy.deepcopy(self.env.observation_space)
    new_shape = list(space.shape)
    # Stacking only grows the last axis.
    new_shape[-1] = space.shape[-1] * self._obs_stack_len
    # ``reshape`` may either mutate the space in place (returning None) or
    # return a new space; honour both instead of discarding a returned value.
    reshaped = space.reshape(tuple(new_shape))
    self.observation_space = space if reshaped is None else reshaped

    # Holder for the stacked observations; presumably filled on reset/step
    # elsewhere in the class -- TODO confirm.
    self._obs_stack = None
评论列表
文章目录