def get_preprocessor(cls, env, options=None):
    """Returns a suitable preprocessor for the given environment.

    Selection order:
      1. A preprocessor registered under the env's spec id in
         ``cls._registered_preprocessor`` is used if present.
      2. Otherwise one is chosen from the observation shape:
         ``()`` -> OneHotPreprocessor, ``cls.ATARI_OBS_SHAPE`` ->
         AtariPixelPreprocessor, ``cls.ATARI_RAM_OBS_SHAPE`` ->
         AtariRamPreprocessor, anything else -> NoPreprocessor.

    Args:
        env (gym.Env): The gym environment to preprocess.
        options (dict): Options to pass to the preprocessor. Every key must
            appear in ``MODEL_CONFIGS``. Defaults to a fresh empty dict on
            each call.

    Returns:
        preprocessor (Preprocessor): Preprocessor for the env observations.

    Raises:
        ValueError: If ``options`` contains a key not in ``MODEL_CONFIGS``.
    """
    # A new dict per call: a shared ``dict()`` default would let any
    # mutation made by a preprocessor leak into every later call.
    if options is None:
        options = {}

    observation_space = env.observation_space
    # For older gym versions that don't set shape for Discrete
    if not hasattr(observation_space, "shape") and \
            isinstance(observation_space, gym.spaces.Discrete):
        observation_space.shape = ()

    # Envs built without the gym registry can have ``spec`` set to None;
    # such envs cannot have a registered preprocessor, so fall through to
    # shape-based selection instead of raising AttributeError.
    spec = getattr(env, "spec", None)
    env_name = spec.id if spec is not None else None
    obs_shape = observation_space.shape

    for k in options:
        if k not in MODEL_CONFIGS:
            raise ValueError(
                "Unknown config key `{}`, all keys: {}".format(
                    k, MODEL_CONFIGS))

    print("Observation shape is {}".format(obs_shape))
    if env_name in cls._registered_preprocessor:
        return cls._registered_preprocessor[env_name](
            observation_space, options)

    if obs_shape == ():
        print("Using one-hot preprocessor for discrete envs.")
        preprocessor = OneHotPreprocessor
    elif obs_shape == cls.ATARI_OBS_SHAPE:
        print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
        preprocessor = AtariPixelPreprocessor
    elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
        print("Assuming Atari ram env, using AtariRamPreprocessor.")
        preprocessor = AtariRamPreprocessor
    else:
        print("Non-atari env, not using any observation preprocessor.")
        preprocessor = NoPreprocessor
    return preprocessor(observation_space, options)
# 评论列表 (Comment list — scraped web-page residue)
# 文章目录 (Table of contents — scraped web-page residue)