defend_the_center.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ViZDoomAgents 作者: GoingMyWay 项目源码 文件源码
def main_train(tf_configs=None, num_workers=16):
    """Build the global A3C network and workers, then train them in parallel threads.

    Resets the default TensorFlow graph, creates ``model_path`` if it is
    missing, builds a shared ``'global'`` network plus ``num_workers`` agents
    (each with its own ``DoomGame``), and runs every agent's ``train_a3c`` in
    its own thread. Blocks until all worker threads have been joined through
    a ``tf.train.Coordinator``, then prints the total wall-clock time.

    This function reads several module-level names defined elsewhere in the
    file: ``model_path``, ``load_model``, ``s_size``, ``a_size``,
    ``max_episode_length``, ``gamma``, and the ``network`` / ``agent`` modules.

    Args:
        tf_configs: Optional session configuration forwarded to ``tf.Session``
            (presumably a ``tf.ConfigProto``).
        num_workers: Number of asynchronous worker agents/threads to launch.
            Defaults to 16, the value previously hard-coded here.

    Raises:
        IOError: If ``load_model`` is true but no checkpoint is recorded in
            ``model_path``.
    """
    s_t = time.time()

    tf.reset_default_graph()

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # The global step counter, optimizer, global network and the workers' ops
    # are all created under an explicit CPU device scope.
    with tf.device("/cpu:0"):
        global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
        optimizer = tf.train.RMSPropOptimizer(learning_rate=1e-5)
        master_network = network.ACNetwork('global', optimizer)  # Generate global network
        # Create worker classes; each gets its own game instance and index.
        agents = [
            agent.Agent(DoomGame(), i, s_size, a_size, optimizer, model_path, global_episodes)
            for i in range(num_workers)
        ]
        saver = tf.train.Saver(max_to_keep=100)

    with tf.Session(config=tf_configs) as sess:
        coord = tf.train.Coordinator()
        if load_model:
            print('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(model_path)
            # get_checkpoint_state returns None when there is no checkpoint
            # file; fail with a clear message instead of an AttributeError.
            if ckpt is None or not ckpt.model_checkpoint_path:
                raise IOError("No checkpoint found in {}".format(model_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        # This is where the asynchronous magic happens.
        # Start the "work" process for each worker in a separate thread.
        worker_threads = []
        for ag in agents:
            # Bind the agent eagerly via target/args. A lambda closing over the
            # loop variable `ag` would resolve it only when the thread runs,
            # so two threads could end up training the same agent.
            t = threading.Thread(
                target=ag.train_a3c,
                args=(max_episode_length, gamma, sess, coord, saver),
            )
            t.start()
            # NOTE(review): presumably staggers worker start-up; not required
            # for correctness now that each thread's agent is bound eagerly.
            time.sleep(0.5)
            worker_threads.append(t)
        coord.join(worker_threads)
    print("training ends, costs{}".format(time.time() - s_t))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号