def make_env(env_id, seed, rank, log_dir):
    """Return a zero-argument factory that builds one configured environment.

    Construction is deferred to the returned thunk so the caller decides when
    (and presumably in which worker process — verify against the caller) the
    environment is actually created.

    Args:
        env_id: Gym environment id, passed to ``gym.make`` and, for Atari
            games, to ``make_atari``.
        seed: Base random seed.
        rank: Index of this environment instance. It is added to ``seed`` so
            each instance is seeded differently, and it names the Monitor log
            file inside ``log_dir``.
        log_dir: Directory for ``bench.Monitor`` output, or ``None`` to skip
            wrapping the environment in a Monitor.

    Returns:
        A callable taking no arguments that returns the wrapped environment.
    """

    def _thunk():
        env = gym.make(env_id)
        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            # The env above was only needed to detect Atari; make_atari is
            # given just the id and builds its own instance, so release the
            # probe env instead of leaking it.
            env.close()
            env = make_atari(env_id)
        env.seed(seed + rank)
        if log_dir is not None:
            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))
        if is_atari:
            env = wrap_deepmind(env)
        # A rank-3 observation whose last axis has size 1 or 3 is treated as
        # a channels-last image. WrapPyTorch is presumably what reorders it
        # for PyTorch convolutions (channels-first) — confirm against its
        # definition.
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in (1, 3):
            env = WrapPyTorch(env)
        return env

    return _thunk
评论列表
文章目录