train_agent_chainer.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:gym-malware 作者: endgameinc 项目源码 文件源码
def create_acer_agent(env):
    """Build an ACER agent whose networks are sized to fit *env*.

    The input width is taken from ``env.observation_space.shape[0]`` and the
    output width from ``env.action_space.n``. The policy and the Q-function
    are two separate networks with the same layer layout
    (input -> 1024 -> 512 -> n_actions, ReLU between layers); they differ
    only in the final wrapper applied to the last layer's output.

    Args:
        env: Environment exposing ``observation_space.shape`` and
            ``action_space.n``.

    Returns:
        The configured ``acer.ACER`` agent.
    """
    n_inputs = env.observation_space.shape[0]
    n_outputs = env.action_space.n

    def build_network(output_wrapper):
        # Every call creates its own layers and initializers, so the policy
        # and Q networks do not share any parameters.
        return links.Sequence(
            L.Linear(n_inputs, 1024, initialW=LeCunNormal(1e-3)),
            F.relu,
            L.Linear(1024, 512, initialW=LeCunNormal(1e-3)),
            F.relu,
            L.Linear(512, n_outputs, initialW=LeCunNormal(1e-3)),
            output_wrapper,
        )

    # The policy network is constructed before the Q network.
    policy_network = build_network(SoftmaxDistribution)
    q_network = build_network(DiscreteActionValue)
    model = acer.ACERSeparateModel(pi=policy_network, q=q_network)

    optimizer = rmsprop_async.RMSpropAsync(lr=7e-4, eps=1e-2, alpha=0.99)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(40))

    def to_float32(obs):
        # copy=False avoids a copy when obs is already float32.
        return obs.astype(np.float32, copy=False)

    episode_buffer = EpisodicReplayBuffer(128)

    return acer.ACER(
        model,
        optimizer,
        gamma=0.95,  # reward discount factor
        t_max=32,  # update the model after this many local steps
        replay_buffer=episode_buffer,
        n_times_replay=4,  # times experience replay is repeated per update
        replay_start_size=64,  # no replay until the buffer holds this many experiences
        disable_online_update=True,  # rely only on the experience buffer
        use_trust_region=True,  # enable trust region policy optimization
        trust_region_delta=0.1,  # a parameter for the trust region update
        truncation_threshold=5.0,  # truncate large importance weights
        beta=1e-2,  # entropy regularization parameter
        phi=to_float32,
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号