main_tf_parallel.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:trpo 作者: jjkke88 项目源码 文件源码
def main(_):
    """Entry point for distributed (parameter-server) TRPO training/testing.

    Builds a TensorFlow cluster from the ``--ps_hosts`` / ``--worker_hosts``
    flags (comma-separated ``host:port`` lists) and starts a server for this
    process's task, selected by ``--job_name`` / ``--task_index``.

    * ``ps`` tasks call ``server.join()`` and block there, serving the shared
      variables.
    * ``worker`` tasks build the gym environment and a ``TRPOAgentParallel``
      under ``tf.train.replica_device_setter``, then open a
      ``Supervisor``-managed session. Worker 0 is the chief. The Supervisor
      checkpoints to ``./checkpoint_parallel`` every 60 seconds, keeping at
      most 10 checkpoints via the Saver.

    If ``pms.train_flag`` is true, every worker runs ``agent.learn()``.
    Otherwise only the chief runs ``agent.test(pms.checkpoint_file)``, and the
    other workers open the managed session and then exit.

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run`` -- confirm against the caller).
    """
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task. job_name and task_index
    # pick which entry of the cluster spec this process is.
    # Each process is capped at 0.1 / 3 (~3.3%) of GPU memory.
    # NOTE(review): presumably this lets several tasks share one GPU -- confirm.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1 / 3.0)
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index,
                             config=tf.ConfigProto(gpu_options=gpu_options))

    if FLAGS.job_name == "ps":
        # Parameter servers only host variables; block here serving them.
        server.join()
    elif FLAGS.job_name == "worker":
        env = Environment(gym.make(pms.environment_name))
        # Place ops on this worker. replica_device_setter puts variables on the
        # ps tasks, so every worker shares the same parameters.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            agent = TRPOAgentParallel(env)
            saver = tf.train.Saver(max_to_keep=10)
            init_op = tf.initialize_all_variables()

        # Create a "supervisor", which oversees the training process.
        # Worker 0 is the chief. Automatic summaries are disabled
        # (summary_op=None); checkpoints are written every 60 seconds.
        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 logdir="./checkpoint_parallel",
                                 init_op=init_op,
                                 global_step=agent.global_step,
                                 saver=saver,
                                 summary_op=None,
                                 save_model_secs=60)

        # The supervisor takes care of session initialization, restoring from
        # a checkpoint, and closing when done or an error occurs.
        with sv.managed_session(server.target) as sess:
            # Hand the managed session (and the supervisor) to the agent and
            # to its gf / sff helpers.
            agent.session = sess
            agent.gf.session = sess
            agent.sff.session = sess
            agent.supervisor = sv

            if pms.train_flag:
                agent.learn()
            elif FLAGS.task_index == 0:
                # Outside training, only the chief evaluates a checkpoint.
                agent.test(pms.checkpoint_file)
        # Ask for all the services to stop.
        # NOTE(review): managed_session() is expected to stop the supervisor
        # on exit already; this explicit call is presumably a safeguard.
        sv.stop()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号