run.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:meta-learning 作者: ioanachelu 项目源码 文件源码
def _make_env():
    """Create one training environment for ``FLAGS.game``.

    Games not listed in ``flags.SUPPORTED_ENVS`` are wrapped in
    ``atari_environment.AtariEnvironment`` using the resize / history-length
    flags, and ``FLAGS.nb_actions`` is then set from the wrapper's
    ``gym_actions``.

    Returns:
        The (possibly wrapped) environment object.
    """
    gym_env = gym.make(FLAGS.game)
    if FLAGS.game not in flags.SUPPORTED_ENVS:
        gym_env = atari_environment.AtariEnvironment(
            gym_env=gym_env,
            resized_width=FLAGS.resized_width,
            resized_height=FLAGS.resized_height,
            agent_history_length=FLAGS.agent_history_length)
        # Side effect: publishes the action-space size through FLAGS. It is set
        # before the networks are built, presumably so they can read it.
        FLAGS.nb_actions = len(gym_env.gym_actions)
    return gym_env


def _restore_or_initialize(sess, saver):
    """Restore the latest checkpoint if ``FLAGS.resume`` is set, else initialize.

    Args:
        sess: the session whose variables are restored or initialized.
        saver: the ``tf.train.Saver`` used for restoring.

    Raises:
        IOError: if resuming was requested but no checkpoint exists under
            ``os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name)``.
    """
    if not FLAGS.resume:
        sess.run(tf.global_variables_initializer())
        return

    ckpt_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    # get_checkpoint_state returns None when there is no checkpoint; fail with a
    # clear message instead of an AttributeError on `ckpt.model_checkpoint_path`.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError("FLAGS.resume is set but no checkpoint was found in {}".format(ckpt_dir))
    print("Loading Model from {}".format(ckpt.model_checkpoint_path))
    saver.restore(sess, ckpt.model_checkpoint_path)


def run():
    """Build the training graph, start one worker thread per agent, and wait.

    Steps:
      1. Reset the output directories and the default TF graph.
      2. On ``/cpu:0``, create the shared ``global_episodes`` step counter, a
         shared Adam optimizer, ``FLAGS.nb_concurrent`` environments, the
         'global' FUNNetwork, one ``Agent`` per environment, and a Saver.
      3. Restore from checkpoint (``FLAGS.resume``) or initialize variables.
      4. Start a thread per agent running ``agent.play(sess, coord, saver)``.
      5. If ``FLAGS.show_training`` is set, render every env from the main
         thread until the coordinator is asked to stop. Then join the workers.
    """
    # NOTE(review): confirm this does not wipe FLAGS.checkpoint_dir, which
    # would make FLAGS.resume unable to find a checkpoint.
    recreate_directory_structure()
    tf.reset_default_graph()

    with tf.Session() as sess:
        with tf.device("/cpu:0"):
            global_step = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
            optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr)

            num_agents = FLAGS.nb_concurrent
            # Envs are created before any network so FLAGS.nb_actions is set first.
            envs = [_make_env() for _ in range(num_agents)]

            # Constructed for its side effect of adding the 'global' network to
            # the default graph; the Python reference is not needed here.
            FUNNetwork('global', None)

            agents = [Agent(env, i, optimizer, global_step) for i, env in enumerate(envs)]
            saver = tf.train.Saver(max_to_keep=5)

        coord = tf.train.Coordinator()
        _restore_or_initialize(sess, saver)

        agent_threads = []
        for agent in agents:
            # Pass the bound method plus args. The previous `lambda: agent.play(...)`
            # read the loop variable late, so a thread could run the wrong agent.
            thread = threading.Thread(target=agent.play, args=(sess, coord, saver))
            thread.start()
            agent_threads.append(thread)

        # The old `while True` never exited (it busy-spun when show_training was
        # off), which made coord.join unreachable.
        # NOTE(review): envs are rendered here while worker threads may be
        # stepping them; confirm that env access is safe across threads.
        while FLAGS.show_training and not coord.should_stop():
            for env in envs:
                env.render()

        coord.join(agent_threads)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号