main.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:ViZDoomAgents 作者: GoingMyWay 项目源码 文件源码
def main_train(tf_configs=None):
    """Build the A3C graph and train every worker agent in its own thread.

    Steps:
      1. Reset the default TF graph and make sure ``cfg.model_path`` exists.
      2. Create the non-trainable ``global_episodes`` counter. Then, under
         ``tf.device("/gpu:0")``, create an RMSProp optimizer, the ``'global'``
         ACNetwork and ``cfg.AGENTS_NUM`` worker agents, each with its own
         ``DoomGame`` instance.
      3. Open a session. Restore the latest checkpoint when the module-level
         ``load_model`` flag is truthy; otherwise initialise all variables.
      4. Run each agent's ``train_a3c`` in a separate thread and wait for all
         of them through a ``tf.train.Coordinator``.

    Relies on module-level names defined outside this function:
    ``load_model``, ``max_episode_length``, ``gamma``, ``cfg``, ``network``,
    ``agent`` and ``DoomGame``.

    Args:
        tf_configs: Optional session configuration forwarded as ``config=``
            to ``tf.Session`` (presumably a ``tf.ConfigProto``). ``None``
            uses TensorFlow's defaults.

    Raises:
        IOError: If ``load_model`` is truthy but no checkpoint is found in
            ``cfg.model_path``.
    """
    start_time = time.time()

    tf.reset_default_graph()

    if not os.path.exists(cfg.model_path):
        os.makedirs(cfg.model_path)

    global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
    with tf.device("/gpu:0"):
        optimizer = tf.train.RMSPropOptimizer(learning_rate=1e-5)
        # Not referenced again in this function. Constructing it presumably adds
        # the shared 'global' network variables that the workers sync with.
        # TODO(review): confirm in network.ACNetwork / agent.Agent.
        global_network = network.ACNetwork('global', optimizer, img_shape=cfg.IMG_SHAPE)  # noqa: F841
        # Create the worker agents: one per configured agent, each driving its own DoomGame.
        agents = [
            agent.Agent(DoomGame(), i, optimizer, cfg.model_path, global_episodes,
                        task_name='D3_battle')
            for i in range(cfg.AGENTS_NUM)
        ]
    # Created after every network is built, so the Saver covers all graph variables.
    saver = tf.train.Saver(max_to_keep=100)

    with tf.Session(config=tf_configs) as sess:
        coord = tf.train.Coordinator()
        if load_model:
            print('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(cfg.model_path)
            # get_checkpoint_state returns None when there is no checkpoint file.
            # Fail with a clear message rather than an AttributeError on ckpt.
            if ckpt is None or not ckpt.model_checkpoint_path:
                raise IOError('No checkpoint found in {}'.format(cfg.model_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        # Run each worker's training loop in a separate thread.
        # Pass the bound method and its arguments directly instead of a closure
        # over the loop variable. A `lambda: ag.train_a3c(...)` resolves `ag`
        # only when the thread runs, so a slow-starting thread could end up
        # training a later agent from this loop.
        worker_threads = []
        for ag in agents:
            thread = threading.Thread(
                target=ag.train_a3c,
                args=(max_episode_length, gamma, sess, coord, saver))
            thread.start()
            # Stagger worker start-up by 0.5 s between thread launches.
            time.sleep(0.5)
            worker_threads.append(thread)
        coord.join(worker_threads)
    print("training ends, costs{}".format(time.time() - start_time))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号