test_integration_dqn.py source file

python

Project: ngraph  Author: NervanaSystems
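import gym
import numpy as np

# NOTE: `dqn`, `rl_loop` and `model` are defined at module level in the original
# test file and are not shown in this excerpt. The custom 'DependentEnv-v0'
# environment must also be registered with gym before this test runs.


def test_dependent_environment():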
    environment = gym.make('DependentEnv-v0')

    total_rewards = []
    # train and evaluate 10 independent agents, each starting from scratch
    for _ in range(10):
        agent = dqn.Agent(
            dqn.space_shape(environment.observation_space),
            environment.action_space,
            model=model,
            epsilon=dqn.decay_generator(start=1.0, decay=0.995, minimum=0.1),
            gamma=0.99,
            learning_rate=0.1,
        )

        rl_loop.rl_loop_train(environment, agent, episodes=10)

        total_rewards.append(
            rl_loop.evaluate_single_episode(environment, agent)
        )

    # at least half of the 10 agents should converge to the perfect policy,
    # i.e. score exactly 100 in their evaluation episode
    assert np.mean(np.array(total_rewards) == 100) >= 0.5
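
Two pieces of the test above are easy to misread: the `epsilon` schedule and the final assertion. The sketch below illustrates both in plain Python. Note that `decay_generator` here is a hypothetical stand-in written only from the keyword arguments visible in the call (`start`, `decay`, `minimum`); it is not taken from the ngraph source, and the real `dqn.decay_generator` may behave differently. `success_rate` is likewise an illustrative helper name, assumed to be equivalent to the `np.mean(np.array(total_rewards) == 100)` expression.

```python
import itertools


def decay_generator(start, decay, minimum):
    """Hypothetical stand-in for dqn.decay_generator (an assumption, not the
    ngraph implementation): yield an exponentially decaying value that is
    never allowed to drop below `minimum`."""
    value = start
    while True:
        yield max(value, minimum)
        value *= decay


def success_rate(total_rewards, perfect_score=100):
    """Fraction of episodes whose reward equals `perfect_score`.

    Plain-Python equivalent of averaging a boolean array: each comparison
    counts as 1 when true and 0 when false.
    """
    hits = sum(1 for reward in total_rewards if reward == perfect_score)
    return hits / len(total_rewards)


# First six exploration rates for a fast-decaying schedule (decay=0.5 is used
# here only to keep the example short; the test above uses 0.995).
epsilons = list(itertools.islice(decay_generator(1.0, 0.5, 0.1), 6))

# Three of four illustrative agents reached the perfect score.
rate = success_rate([100, 100, 40, 100])
```

With these made-up rewards, `rate` is 0.75, which would satisfy the `>= 0.5` threshold in the test. The threshold is deliberately loose because each agent trains for only 10 episodes and exploration is random, so some agents are expected to fall short.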