def run(target, is_chief, train_steps, job_dir, train_files, eval_files, num_epochs, learning_rate):
    """Train the model in a monitored session, evaluating on the chief.

    Builds a separate evaluation graph (chief only) and a training graph,
    collects the session hooks, and then runs ``train_op`` until the
    monitored session reports that it should stop.

    Args:
        target: Passed as ``master`` to ``tf.train.MonitoredTrainingSession``.
        is_chief: Whether this process is the chief. Only the chief builds
            the evaluation graph and registers the evaluation hook.
        train_steps: Step count handed to ``tf.train.StopAtStepHook``.
        job_dir: Used as the checkpoint directory and passed to the
            evaluation hook, the checkpoint exporter hook and
            ``get_summary_dir``.
        train_files: Passed to ``model.input_fn`` for the training input.
        eval_files: Passed to ``model.input_fn`` for the evaluation input.
        num_epochs: Number of epochs passed to ``model.input_fn`` for the
            training input (the evaluation input always uses 1).
        learning_rate: Passed to ``model.model_fn`` in both modes.
    """
    num_channels = 6

    # NOTE: per the original author, this may not work well in distributed
    # mode because it might only count local steps (unconfirmed).
    hooks = [tf.train.StopAtStepHook(train_steps)]

    if is_chief:
        evaluation_graph = tf.Graph()
        with evaluation_graph.as_default():
            # Features and label tensors
            image, ground_truth, name = model.input_fn(eval_files, 1, shuffle=False, shared_name=None)
            # Returns dictionary of tensors to be evaluated
            metric_dict = model.model_fn(model.EVAL, name, image, ground_truth, num_channels, learning_rate)

        # Hook that performs evaluation separate from training
        hooks.append(EvaluationRunHook(job_dir, metric_dict, evaluation_graph))
        # NOTE(review): assumed to be chief-only, next to the evaluation
        # hook; the original indentation was lost -- confirm.
        hooks.append(CheckpointExporterHook(job_dir))

    # Create a new graph and specify that as default
    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter()):
            # Features and label tensors as read using filename queue
            image, ground_truth, name = model.input_fn(
                train_files, num_epochs, shuffle=True, shared_name='train_queue')
            # Returns the training op, the console logging hook and the
            # training summaries
            train_op, log_hook, train_summaries = model.model_fn(
                model.TRAIN, name, image, ground_truth, num_channels, learning_rate)

        # Hook that logs training to the console
        hooks.append(log_hook)
        train_summary_hook = tf.train.SummarySaverHook(
            save_steps=1,
            output_dir=get_summary_dir(job_dir),
            summary_op=train_summaries)
        hooks.append(train_summary_hook)

        # Creates a MonitoredSession for training.
        # MonitoredSession is a Session-like object that handles
        # initialization, recovery and hooks
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        # NOTE(review): the session is created inside the training graph
        # scope but outside the device scope -- confirm against the
        # original layout.
        with tf.train.MonitoredTrainingSession(master=target,
                                               is_chief=is_chief,
                                               checkpoint_dir=job_dir,
                                               hooks=hooks,
                                               save_checkpoint_secs=60 * 3,
                                               save_summaries_steps=1,
                                               log_step_count_steps=5) as session:
            # Run the training op repeatedly. Once the training epochs are
            # exhausted (or a hook requests a stop), session.should_stop()
            # becomes true and the loop exits.
            while not session.should_stop():
                session.run(train_op)
# (Web-page navigation residue, not part of the program: "Comment list" / "Table of contents".)