import gym
import tensorflow as tf

# The original post only shows main(); the flag definitions below follow the
# standard tf.app.flags pattern for distributed TensorFlow, and the defaults
# are assumptions.
flags = tf.app.flags
flags.DEFINE_string("ps_hosts", "", "Comma-separated ps host:port pairs")
flags.DEFINE_string("worker_hosts", "", "Comma-separated worker host:port pairs")
flags.DEFINE_string("job_name", "", "One of 'ps' or 'worker'")
flags.DEFINE_integer("task_index", 0, "Index of the task within its job")
FLAGS = flags.FLAGS

# Environment, TRPOAgentParallel and pms (the parameter module) are part of
# the author's project and are assumed to be importable from it.

def main(_):
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
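  # Example (illustrative): --ps_hosts=host0:2222
  # --worker_hosts=host1:2222,host2:2222 yields a cluster with one "ps" task
  # and two "worker" tasks.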
  # Create and start a server for the local task. Every process in the
  # cluster runs one of these servers; job_name and task_index identify
  # which entry of the ClusterSpec this process owns.
  # Cap per-process GPU memory so that several tasks can share one GPU
  # (0.1 / 3.0, i.e. about 3.3% of the device per process).
  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1 / 3.0)
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index,
                           config=tf.ConfigProto(gpu_options=gpu_options))
  if FLAGS.job_name == "ps":
    # Parameter servers only host the shared variables; block here and
    # serve until the process is killed.
    server.join()
  elif FLAGS.job_name == "worker":
    env = Environment(gym.make(pms.environment_name))
    # replica_device_setter pins each op to the local worker and
    # automatically places the variables on the parameter servers.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
      agent = TRPOAgentParallel(env)
      saver = tf.train.Saver(max_to_keep=10)
      init_op = tf.initialize_all_variables()
      summary_op = tf.merge_all_summaries()
# Create a "supervisor", which oversees the training process.
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
logdir="./checkpoint_parallel",
init_op=init_op,
global_step=agent.global_step,
saver=saver,
summary_op=None,
save_model_secs=60)
    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target) as sess:
      agent.session = sess
      agent.gf.session = sess
      agent.sff.session = sess
      agent.supervisor = sv
      if pms.train_flag:
        agent.learn()
      elif FLAGS.task_index == 0:
        agent.test(pms.checkpoint_file)
    # Ask for all the services to stop.
    sv.stop()
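
# Standard tf.app entry point (assumed; the original snippet ends at main()).
# Launch one process per cluster entry, for example:
#   python trpo_parallel.py --ps_hosts=localhost:2222 \
#       --worker_hosts=localhost:2223,localhost:2224 \
#       --job_name=ps --task_index=0
#   python trpo_parallel.py --ps_hosts=localhost:2222 \
#       --worker_hosts=localhost:2223,localhost:2224 \
#       --job_name=worker --task_index=0
# (the script name is illustrative)
if __name__ == "__main__":
  tf.app.run()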