def __init__(self, env, monitor_path, video=True, **usercfg):
    """Set up a cross-entropy-method (CEM) agent for ``env``.

    Args:
        env: Gym environment to train on. It is wrapped in
            ``wrappers.Monitor`` and the wrapped env is stored in ``self.env``.
        monitor_path: Output directory handed to ``wrappers.Monitor``
            (opened with ``force=True``).
        video: If False, ``video_callable=False`` is passed to the Monitor to
            disable video recording; if True, ``video_callable=None`` is
            passed, leaving recording to the Monitor's default.
        **usercfg: Configuration overrides. They are forwarded to the base
            class constructor and re-applied after the CEM defaults below,
            so user-supplied values take precedence over those defaults.

    Raises:
        NotImplementedError: If ``env.action_space`` is neither a
            ``Discrete`` nor a ``Box`` space.
    """
    super(CEM, self).__init__(**usercfg)
    self.env = wrappers.Monitor(env, monitor_path, force=True, video_callable=(None if video else False))

    # Episode step limit. ``env.spec`` may be None (env not created through
    # the registry), and newer gym releases dropped ``spec.tags`` in favour of
    # ``spec.max_episode_steps``. Read the legacy tag first so older setups
    # behave exactly as before, then fall back to the attribute. None means
    # no limit could be found, the same value a missing tag produced before.
    spec = getattr(env, "spec", None)
    tags = getattr(spec, "tags", None) or {}
    num_steps = tags.get("wrapper_config.TimeLimit.max_episode_steps")
    if num_steps is None:
        num_steps = getattr(spec, "max_episode_steps", None)

    self.config.update(dict(
        num_steps=num_steps,  # maximum length of episode
        n_iter=100,  # number of iterations of CEM
        batch_size=25,  # number of samples per batch
        elite_frac=0.2  # fraction of samples used as elite set
    ))
    # Re-apply the user's settings so they override the defaults just set.
    self.config.update(usercfg)

    # Size of the parameter vector: (observation features + 1) per action
    # output. The +1 presumably holds a bias term -- confirm against the code
    # that reshapes theta into a policy.
    action_space = env.action_space
    if isinstance(action_space, Discrete):
        n_outputs = action_space.n
    elif isinstance(action_space, Box):
        n_outputs = action_space.shape[0]
    else:
        raise NotImplementedError(
            "CEM supports only Discrete and Box action spaces, got {}".format(
                type(action_space).__name__))
    self.dim_theta = (env.observation_space.shape[0] + 1) * n_outputs

    # Initialize mean and standard deviation of the parameter distribution:
    # mean 0 and standard deviation 1 for every parameter.
    self.theta_mean = np.zeros(self.dim_theta)
    self.theta_std = np.ones(self.dim_theta)
评论列表
文章目录