def run(args, server):
    """Run one A3C worker: build the graph, join the cluster and train.

    Creates the environment and an ``A3C`` trainer for task ``args.task``,
    sets up checkpointing and event summaries through a
    ``tf.train.Supervisor`` (task 0 is the chief), connects to
    ``server.target`` and repeatedly calls ``trainer.process`` until the
    supervisor asks to stop or the global step reaches the step budget.

    Args:
        args: Parsed command-line options. Reads ``env_id``, ``task``,
            ``remotes`` and ``log_dir``. ``num_global_steps`` is read when
            present and not None (default 100000000); a falsy budget such
            as 0 disables the step limit.
        server: Object whose ``target`` is handed to
            ``Supervisor.managed_session`` (presumably a ``tf.train.Server``
            -- confirm against the caller).
    """
    env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes)
    trainer = A3C(env, args.task)

    # TensorFlow 0.12 renamed the variable/initializer/summary-writer APIs and
    # 1.0 removed the old names; use whichever set this TF build provides.
    use_tf12_api = hasattr(tf, "global_variables")

    # Variable names that start with "local" are not saved in checkpoints.
    if use_tf12_api:
        all_variables = tf.global_variables()
    else:
        all_variables = tf.all_variables()
    variables_to_save = [v for v in all_variables if not v.name.startswith("local")]
    if use_tf12_api:
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    else:
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()
    saver = FastSaver(variables_to_save)

    def init_fn(ses):
        # Passed to the Supervisor as init_fn: runs the initializer for every
        # variable, including the "local" ones left out of variables_to_save.
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    # Only the parameter-server job and this worker's own CPU are visible to
    # the session created below.
    config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)])
    logdir = os.path.join(args.log_dir, 'train')

    # Each worker writes its events to its own "<logdir>_<task>" directory.
    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)
    logger.info("Events directory: %s_%s", logdir, args.task)

    sv = tf.train.Supervisor(is_chief=(args.task == 0),
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=trainer.global_step,
                             save_model_secs=30,
                             save_summaries_secs=30)

    # Step budget; may be overridden by args.num_global_steps. A falsy value
    # (e.g. 0) disables the limit -- see the loop condition below.
    num_global_steps = getattr(args, "num_global_steps", None)
    if num_global_steps is None:
        num_global_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're most likely waiting to connect to the parameter server. "
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
评论列表
文章目录