def env_factory(cmdl, mode):
    """Create a gym environment and wrap it according to ``mode``.

    Args:
        cmdl: Configuration object. ``env_name``, ``env_class``, ``hist_len``
            and ``cuda`` are always read; ``rescale_dims``, ``reward_clamp``
            and ``done_after_lost_life`` are optional; ``eval_env_name`` is
            read only when ``mode`` is ``"evaluation"``.
        mode: ``"training"`` or ``"evaluation"``.

    Returns:
        The wrapped environment for the two known modes, ``None`` for any
        other value of ``mode``.
    """
    # Undo gym's default logger setup and let only warnings (and above)
    # through on the root logger.
    gym.undo_logger_setup()
    logging.getLogger().setLevel(logging.WARNING)

    print(clr("[Main] Constructing %s environment." % mode, attrs=['bold']))
    env = gym.make(cmdl.env_name)

    # Frame size: an explicit square rescale if configured, otherwise the
    # first two dimensions of the environment's observation space.
    if hasattr(cmdl, 'rescale_dims'):
        state_dims = (cmdl.rescale_dims, cmdl.rescale_dims)
    else:
        state_dims = env.observation_space.shape[0:2]

    env_class = cmdl.env_class
    hist_len = cmdl.hist_len
    cuda = cmdl.cuda
    separator = '-' * 50

    if mode == "training":
        env = PreprocessFrames(env, env_class, hist_len, state_dims, cuda)
        # Optional wrappers are applied only when the flag exists and is truthy.
        if getattr(cmdl, 'reward_clamp', False):
            env = SqueezeRewards(env)
        if getattr(cmdl, 'done_after_lost_life', False):
            env = DoneAfterLostLife(env)
        print(separator)
        return env

    if mode == "evaluation":
        eval_name = cmdl.eval_env_name
        if eval_name != cmdl.env_name:
            warning = ("[%s] Warning! evaluating on a different env: %s"
                       % ("Main", eval_name))
            print(clr(warning, 'red', attrs=['bold']))
            # NOTE(review): state_dims above was derived from the env built
            # from cmdl.env_name and is not recomputed for this one -- confirm
            # that is intended.
            env = gym.make(eval_name)
        env = PreprocessFrames(env, env_class, hist_len, state_dims, cuda)
        env = EvaluationMonitor(env, cmdl)
        print(separator)
        return env

    # Unrecognised modes fall through: nothing is wrapped, None is returned.
    return None
# Web-page residue (not code): "Comment list"
# Web-page residue (not code): "Table of contents"