run_dqn_atari.py source code

Project: rl_algorithms    Author: DanielTakeshi

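```python
import argparse

import gym  # gym.benchmark_spec exists only in older gym releases


def main():
    # Games that we'll be testing, mapped to their index in the Atari40M
    # benchmark's task list.
    game_to_ID = {'BeamRider': 0,
                  'Breakout': 1,
                  'Enduro': 2,
                  'Pong': 3,
                  'Qbert': 4}

    # Parse command-line arguments. The num_timesteps default (40M) matches
    # the Atari40M tasks' default. `choices` rejects unknown game names with
    # a usage error instead of a KeyError further down.
    parser = argparse.ArgumentParser()
    parser.add_argument('--game', type=str, default='Pong',
                        choices=sorted(game_to_ID))
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_timesteps', type=int, default=40000000)
    args = parser.parse_args()

    # Choose the game to play and set the log file name, e.g. "Pong_s000.pkl".
    benchmark = gym.benchmark_spec('Atari40M')
    task = benchmark.tasks[game_to_ID[args.game]]
    log_name = args.game + "_s" + str(args.seed).zfill(3) + ".pkl"

    # Run training. Should change the seed if possible!
    # Also, the actual number of iterations run is _roughly_ num_timesteps/4.
    # get_env, get_session and atari_learn are defined elsewhere in this file.
    seed = args.seed
    env = get_env(task, seed)
    session = get_session()
    print("task = {}".format(task))
    atari_learn(env,
                session,
                num_timesteps=args.num_timesteps,
                log_file=log_name)


if __name__ == "__main__":
    main()
```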
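The argument handling, log-file naming, and iteration estimate in `main()` can be checked without installing gym. What follows is a minimal sketch that assumes three hypothetical helpers, `parse_args`, `make_log_name`, and `approx_iterations`. None of them exist in the original file. They only mirror the parser, the `log_name` expression, and the "roughly num_timesteps/4" comment above.

```python
import argparse

# Same game-to-task-index mapping as in main().
GAME_TO_ID = {'BeamRider': 0, 'Breakout': 1, 'Enduro': 2, 'Pong': 3, 'Qbert': 4}


def parse_args(argv=None):
    """Hypothetical helper mirroring main()'s parser; argv=None reads sys.argv."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--game', type=str, default='Pong',
                        choices=sorted(GAME_TO_ID))
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_timesteps', type=int, default=40000000)
    return parser.parse_args(argv)


def make_log_name(game, seed):
    """Hypothetical helper: the same expression main() uses for log_name."""
    return game + "_s" + str(seed).zfill(3) + ".pkl"


def approx_iterations(num_timesteps):
    """Rough iteration count, per the source comment: about num_timesteps / 4."""
    return num_timesteps // 4


if __name__ == "__main__":
    args = parse_args(['--game', 'Breakout', '--seed', '7'])
    print(GAME_TO_ID[args.game], make_log_name(args.game, args.seed),
          approx_iterations(args.num_timesteps))
    # prints: 1 Breakout_s007.pkl 10000000
```

Passing an explicit `argv` list lets you exercise the parser without touching `sys.argv`. With the default budget of 40,000,000 timesteps, the comment's estimate works out to about 10,000,000 iterations.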