catalog.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:ray 作者: ray-project 项目源码 文件源码
def get_preprocessor(cls, env, options=None):
    """Returns a suitable preprocessor for the given environment.

    A preprocessor registered in ``cls._registered_preprocessor`` under the
    env's spec id takes precedence. Otherwise one is chosen from the
    observation shape: ``()`` (discrete) -> OneHotPreprocessor,
    ``cls.ATARI_OBS_SHAPE`` -> AtariPixelPreprocessor,
    ``cls.ATARI_RAM_OBS_SHAPE`` -> AtariRamPreprocessor, anything else ->
    NoPreprocessor.

    Args:
        env (gym.Env): The gym environment to preprocess.
        options (dict): Options to pass to the preprocessor. Every key must
            appear in MODEL_CONFIGS. Defaults to a new empty dict per call.

    Returns:
        preprocessor (Preprocessor): Preprocessor for the env observations.

    Raises:
        Exception: If ``options`` contains a key not in MODEL_CONFIGS.
    """
    # Build a fresh dict per call: a shared ``dict()`` default would let any
    # mutation made by one preprocessor leak into every later default call.
    if options is None:
        options = {}

    observation_space = env.observation_space

    # For older gym versions that don't set shape for Discrete
    if not hasattr(observation_space, "shape") and \
            isinstance(observation_space, gym.spaces.Discrete):
        observation_space.shape = ()

    # ``env.spec`` may be None (or missing) for envs not built from a
    # registered spec; such envs cannot match a registered preprocessor, so
    # they fall through to shape-based selection instead of crashing.
    spec = getattr(env, "spec", None)
    env_name = spec.id if spec is not None else None
    obs_shape = observation_space.shape

    for k in options:
        if k not in MODEL_CONFIGS:
            raise Exception(
                "Unknown config key `{}`, all keys: {}".format(
                    k, MODEL_CONFIGS))

    print("Observation shape is {}".format(obs_shape))

    if env_name is not None and env_name in cls._registered_preprocessor:
        return cls._registered_preprocessor[env_name](
            observation_space, options)

    if obs_shape == ():
        print("Using one-hot preprocessor for discrete envs.")
        preprocessor = OneHotPreprocessor
    elif obs_shape == cls.ATARI_OBS_SHAPE:
        print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
        preprocessor = AtariPixelPreprocessor
    elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
        print("Assuming Atari ram env, using AtariRamPreprocessor.")
        preprocessor = AtariRamPreprocessor
    else:
        print("Non-atari env, not using any observation preprocessor.")
        preprocessor = NoPreprocessor

    return preprocessor(observation_space, options)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号