def train(env_id, num_timesteps, seed):
    """Train a TRPO agent with a small MLP policy on a Gym environment.

    Each MPI process derives its own seed from ``seed`` and its rank, wraps
    the environment in a ``bench.Monitor`` and then runs ``trpo_mpi.learn``.
    Logging is disabled on every process except rank 0.

    Args:
        env_id: Environment id passed to ``gym.make``.
        num_timesteps: Forwarded to ``trpo_mpi.learn`` as ``max_timesteps``.
        seed: Base random seed; the per-process seed is
            ``seed + 10000 * rank``.
    """
    import baselines.common.tf_util as U

    # The session is entered here and never exited inside this function, so
    # it is still entered when train() returns.
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank != 0:
        logger.set_level(logger.DISABLED)

    # Offset the seed by rank so that each MPI process gets a different one.
    workerseed = seed + 10000 * rank
    set_global_seeds(workerseed)

    def policy_fn(name, ob_space, ac_space):
        # Build the policy from the spaces handed in by the learner instead of
        # reaching into the enclosing ``env`` variable, which is rebound below.
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                         hid_size=32, num_hid_layers=2)

    env = gym.make(env_id)
    try:
        # When logger.get_dir() is falsy, that falsy value is passed to the
        # monitor in place of a per-rank path.
        log_dir = logger.get_dir()
        env = bench.Monitor(env, log_dir and osp.join(log_dir, str(rank)))
        env.seed(workerseed)
        gym.logger.setLevel(logging.WARN)

        trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024, max_kl=0.01,
                       cg_iters=10, cg_damping=0.1,
                       max_timesteps=num_timesteps, gamma=0.99, lam=0.98,
                       vf_iters=5, vf_stepsize=1e-3)
    finally:
        # Release the environment even when setup or training raises.
        env.close()
评论列表
文章目录