a3c.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:reinforcement-learning 作者: cgnicholls 项目源码 文件源码
def a3c(game_name, num_threads=8, restore=None, save_path='model'):
    """Run asynchronous training with worker threads plus one evaluator thread.

    Builds ``num_threads + 1`` environments (the first is reserved for
    evaluation), creates a single shared ``Agent`` inside one TensorFlow
    session, then starts ``num_threads`` ``async_trainer`` threads and one
    ``evaluator`` thread. The calling thread polls the module-level
    ``training_finished`` flag and joins every thread once it becomes true.

    Args:
        game_name: Environment identifier. ``'CartPole-v0'`` is wrapped in
            ``CustomGymClassicControl``; every other name is wrapped in
            ``CustomGym``.
        num_threads: Number of trainer threads to start. One additional
            environment and thread are created for evaluation.
        restore: If not ``None``, parameters are restored from the checkpoint
            ``save_path + '-' + str(restore)`` and this value is used as the
            initial entry placed on the shared T queue. Otherwise variables
            are freshly initialised and the queue starts at 0.
        save_path: Checkpoint path prefix; also handed to ``Summary`` and to
            every trainer/evaluator thread.
    """
    processes = []
    envs = []
    for _ in range(num_threads + 1):
        # The original also called gym.make(game_name) here and discarded the
        # result, leaving one unused, never-closed environment per iteration.
        # Only the wrapper environments below are ever used.
        if game_name == 'CartPole-v0':
            env = CustomGymClassicControl(game_name)
        else:
            print("Assuming ATARI game and playing with pixels")
            env = CustomGym(game_name)
        envs.append(env)

    # Separate out the evaluation environment
    evaluation_env = envs[0]
    envs = envs[1:]

    with tf.Session() as sess:
        agent = Agent(
            session=sess,
            action_size=envs[0].action_size,
            model='mnih',
            optimizer=tf.train.AdamOptimizer(INITIAL_LEARNING_RATE))

        # Create a saver, and only keep 2 checkpoints.
        saver = tf.train.Saver(max_to_keep=2)

        # Queue handed to every thread; seeded below with the starting T.
        T_queue = Queue.Queue()

        # Either restore the parameters or don't.
        if restore is not None:
            saver.restore(sess, save_path + '-' + str(restore))
            last_T = restore
            # Formatted so the output matches on both Python 2 and Python 3.
            print("T was: %s" % (last_T,))
            T_queue.put(last_T)
        else:
            sess.run(tf.global_variables_initializer())
            T_queue.put(0)

        summary = Summary(save_path, agent)

        # Create a thread for each worker
        for i in range(num_threads):
            processes.append(threading.Thread(
                target=async_trainer,
                args=(agent, envs[i], sess, i, T_queue, summary, saver,
                      save_path,)))

        # Create a thread to evaluate the agent
        processes.append(threading.Thread(
            target=evaluator,
            args=(agent, evaluation_env, sess, T_queue, summary, saver,
                  save_path,)))

        # Start all the threads. They are daemonic, so they do not keep the
        # interpreter alive if the main thread exits.
        for p in processes:
            p.daemon = True
            p.start()

        # Until training is finished. `training_finished` is not assigned in
        # this function, so it is looked up as a module-level name.
        while not training_finished:
            sleep(0.01)

        # Join the threads, so we get this thread back.
        for p in processes:
            p.join()

# Returns sum(rewards[i] * gamma**i)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号