# Imports follow the rllab layout used by the CPO codebase (jachiam/cpo);
# the sandbox.cpo module paths may differ in other forks.
import lasagne.nonlinearities as NL

from rllab.baselines.gaussian_mlp_baseline import GaussianMLPBaseline
from rllab.envs.mujoco.gather.point_gather_env import PointGatherEnv
from rllab.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer
from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy

from sandbox.cpo.algos.safe.cpo import CPO
from sandbox.cpo.safety_constraints.gather import GatherSafetyConstraint


def run_task(*_):
    trpo_stepsize = 0.01
    trpo_subsample_factor = 0.2

    # Point-Gather environment: +10 reward per apple collected, cost of 1 per bomb hit.
    env = PointGatherEnv(apple_reward=10, bomb_cost=1, n_apples=2, activity_range=6)

    # Gaussian MLP policy with two hidden layers.
    policy = GaussianMLPPolicy(env.spec, hidden_sizes=(64, 32))

    # Value baseline fit to the reward returns.
    baseline = GaussianMLPBaseline(
        env_spec=env.spec,
        regressor_args={
            'hidden_sizes': (64, 32),
            'hidden_nonlinearity': NL.tanh,
            'learn_std': False,
            'step_size': trpo_stepsize,
            'optimizer': ConjugateGradientOptimizer(subsample_factor=trpo_subsample_factor),
        },
    )

    # Separate baseline fit to the safety (cost) returns.
    safety_baseline = GaussianMLPBaseline(
        env_spec=env.spec,
        regressor_args={
            'hidden_sizes': (64, 32),
            'hidden_nonlinearity': NL.tanh,
            'learn_std': False,
            'step_size': trpo_stepsize,
            'optimizer': ConjugateGradientOptimizer(subsample_factor=trpo_subsample_factor),
        },
        target_key='safety_returns',
    )

    # Constraint: expected discounted cost per trajectory must stay below 0.1.
    safety_constraint = GatherSafetyConstraint(max_value=0.1, baseline=safety_baseline)

    algo = CPO(
        env=env,
        policy=policy,
        baseline=baseline,
        safety_constraint=safety_constraint,
        safety_gae_lambda=1,
        batch_size=50000,
        max_path_length=15,
        n_itr=100,
        gae_lambda=0.95,
        discount=0.995,
        step_size=trpo_stepsize,
        optimizer_args={'subsample_factor': trpo_subsample_factor},
        # plot=True,
    )
    algo.train()
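
As written, run_task is only defined, never invoked. In rllab-style experiment scripts it is normally handed to run_experiment_lite, which handles logging, seeding, and parallel sampling. A minimal launch sketch follows; the n_parallel, snapshot_mode, seed, and exp_prefix values are illustrative assumptions, not taken from the original script.

# Minimal launch sketch: hand run_task to rllab's experiment runner.
# The keyword values below are illustrative assumptions.
from rllab.misc.instrument import run_experiment_lite

run_experiment_lite(
    run_task,
    n_parallel=4,          # number of parallel sampling workers
    snapshot_mode="last",  # keep only the latest snapshot
    seed=1,
    exp_prefix="CPO-PointGather",
)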