cem.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:DeepRL 作者: arnomoonens 项目源码 文件源码
def __init__(self, env, monitor_path, video=True, **usercfg):
        """Set up the monitored environment, the configuration and the
        initial parameter distribution.

        Args:
            env: environment to run on; it is wrapped in ``wrappers.Monitor``.
                Its action space must be ``Discrete`` or ``Box``.
            monitor_path: directory handed to the monitor wrapper.
            video: when true, ``video_callable=None`` is passed to the monitor;
                when false, ``video_callable=False`` is passed instead.
            **usercfg: configuration overrides. They are forwarded to the
                parent constructor and also applied on top of the defaults
                defined here.

        Raises:
            NotImplementedError: if the action space is neither ``Discrete``
                nor ``Box``.
        """
        super(CEM, self).__init__(**usercfg)

        video_callable = None if video else False
        self.env = wrappers.Monitor(
            env,
            monitor_path,
            force=True,
            video_callable=video_callable,
        )

        defaults = {
            # maximum length of episode
            "num_steps": env.spec.tags.get("wrapper_config.TimeLimit.max_episode_steps"),
            # number of iterations of CEM
            "n_iter": 100,
            # number of samples per batch
            "batch_size": 25,
            # fraction of samples used as elite set
            "elite_frac": 0.2,
        }
        self.config.update(defaults)
        # User-supplied values take precedence over the defaults above.
        self.config.update(usercfg)

        # Number of policy outputs depends on the kind of action space.
        action_space = env.action_space
        if isinstance(action_space, Discrete):
            n_outputs = action_space.n
        elif isinstance(action_space, Box):
            n_outputs = action_space.shape[0]
        else:
            raise NotImplementedError

        # One parameter per (observation feature + 1) for each output;
        # the extra 1 is presumably a bias term -- confirm against the policy code.
        self.dim_theta = (env.observation_space.shape[0] + 1) * n_outputs

        # Initialize mean and standard deviation
        self.theta_std = np.ones(self.dim_theta)
        self.theta_mean = np.zeros(self.dim_theta)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号