def get_supervisor(model):
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    supervisor = tf.train.Supervisor(
        logdir=FLAGS.model_dir,
        is_chief=True,
        saver=saver,
        init_op=set_initial_ops(),
        summary_op=tf.summary.merge_all(),
        summary_writer=summary_writer,
        save_summaries_secs=100,  # TODO: add as flags
        save_model_secs=1000,
        global_step=model.global_step,
    )
    return supervisor
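# Usage sketch (not from the original source): a Supervisor is typically
# driven through managed_session(), which handles initialization, checkpoint
# recovery, and the periodic summary/checkpoint saving configured above.
# `model` is assumed to expose `train_op` in addition to `global_step`.
def train(model, num_steps=1000):
    supervisor = get_supervisor(model)
    with supervisor.managed_session(config=get_sess_config()) as sess:
        for _ in range(num_steps):
            if supervisor.should_stop():
                break
            sess.run(model.train_op)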
def register_dataset_flags():
    logging.info("Registering Dataset flags")
    flags.DEFINE_integer("batch_size", 128,
                         "Size of the batch of the dataset iterator.")
    flags.DEFINE_integer("buffer_size", 512,
                         "Size of the buffer of the dataset iterator.")
    flags.DEFINE_integer("take_count", -1,
                         "Creates a `Dataset` with at most `count` batches from this dataset.")
    flags.DEFINE_string("train_subdir", "train",
                        "Location of training TFRecords, within the training set dir.")
    flags.DEFINE_string("eval_subdir", "eval",
                        "Location of eval TFRecords, within the training set dir.")
def get_sess_config():
    # gpu_options = tf.GPUOptions(
    #     per_process_gpu_memory_fraction=self.gpu_memory_fraction,
    #     allow_growth=True)  # seems to be not working
    sess_config = tf.ConfigProto(
        # log_device_placement=True,
        inter_op_parallelism_threads=8,  # TODO: add as flags
        # allow_soft_placement=True,
        # gpu_options=gpu_options,
    )
    return sess_config
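# Usage sketch (an assumption, not shown in the original snippet): the config
# is passed wherever a session is created, e.g. tf.Session or
# Supervisor.managed_session.
sess = tf.Session(config=get_sess_config())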
def run_benchmark_distributed():
    # `host`, `create_graph`, `session_config`, and `run_benchmark` are
    # defined elsewhere in the benchmark script.
    ops = create_graph("/job:worker/task:0", "/job:worker/task:1")
    queues = [create_done_queue(0), create_done_queue(1)]

    # Launch the distributed service.
    port0, port1 = [portpicker.pick_unused_port() for _ in range(2)]
    flags = " ".join(sys.argv)  # pass parent flags to children

    def run_worker(w):
        my_env = os.environ.copy()
        if not FLAGS.verbose:
            my_env["CUDA_VISIBLE_DEVICES"] = ""
            my_env["TF_CPP_MIN_LOG_LEVEL"] = "2"
        if FLAGS.profile:
            my_env["LD_PRELOAD"] = "/usr/lib/libtcmalloc_and_profiler.so.4"
            my_env["CPUPROFILE"] = "/tmp/profile.out.%s" % (w)
        cmd = "python %s --task=%d --port0=%s --port1=%s" % (flags, w, port0, port1)
        subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT,
                         env=my_env)

    run_worker(0)
    run_worker(1)

    sess = tf.Session("grpc://%s:%s" % (host, port0), config=session_config())
    rate = run_benchmark(sess, *ops)

    # Bring down the workers.
    if FLAGS.verbose:
        print("Killing workers.")
    sess.run(queues[1].enqueue(1))
    # TODO: sleep to avoid killing the master too early?
    sess.run(queues[0].enqueue(1))  # bring down master last
    return rate
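# `create_done_queue` is not shown above. A common pattern for such a helper
# (an assumption inferred from its usage here, not the original definition)
# is a shared FIFO queue pinned to each worker; the worker blocks on a
# dequeue and exits once the master enqueues a token:
def create_done_queue(i):
    """Queue used to signal worker i that the benchmark has finished."""
    with tf.device("/job:worker/task:%d" % i):
        return tf.FIFOQueue(1, tf.int32, shared_name="done_queue%d" % i)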
def register_core_flags():
    logging.info("Registering core spotify-tensorflow flags")
    flags.DEFINE_string("training_set", None,
                        "Location of the training set")
    flags.DEFINE_string("job-dir", None,
                        "Where to write data")
def restore_config(config_file):
    # Pickle files must be opened in binary mode.
    with open(config_file, 'rb') as f:
        flags = pickle.load(f)
    for k, v in flags.items():
        setattr(FLAGS, k, v)

def save_config(config_file):
    with open(config_file, 'wb') as f:
        flags = get_flags()
        # Persist only the whitelisted subset of flags.
        saved_flags = {k: flags[k] for k in _SAVE_FLAGS}
        pickle.dump(saved_flags, f)
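# Round-trip sketch (the path and the contents of _SAVE_FLAGS are
# assumptions): persist the whitelisted flags before training, then restore
# them in a fresh process that defines the same flags.
save_config('/tmp/run_config.pkl')
# ... later, e.g. when reloading the model for evaluation ...
restore_config('/tmp/run_config.pkl')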
def vocab_size(self):
    return self._vocab_size

# Define flags from the t2t binaries
def __init__(self, data_dir, model_dir):
    """Creates the Transformer estimator.

    Args:
      data_dir: The training data directory.
      model_dir: The trained model directory.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    FLAGS.output_dir = model_dir
    FLAGS.data_dir = data_dir
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(data_dir)

    # Create the basic hyper parameters.
    self.hparams = tpu_trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=data_dir,
        problem_name=FLAGS.problems)
    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters.
    self.estimator = tpu_trainer_lib.create_estimator(
        FLAGS.model,
        self.hparams,
        tpu_trainer.create_run_config(),
        decode_hp,
        use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problems[0].vocabulary["inputs"]
    self.targets_vocab = self.hparams.problems[0].vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
        shutil.rmtree(run_dir)
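# Instantiation sketch: the enclosing class is not shown above, so the name
# `TransformerServer` is a placeholder. The tensor2tensor flags referenced in
# __init__ (hparams_set, hparams, model, problems, decode_hparams,
# t2t_usr_dir) must be set before construction.
server = TransformerServer(data_dir="~/t2t/data", model_dir="~/t2t/train")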