def _run_export(self):
    """Export ``self._latest_checkpoint`` as a SavedModel, at most once.

    The export directory is ``<self._checkpoint_dir>/export_ckpt_<N>``, where
    ``N`` is the last run of digits found in the checkpoint path (for a path
    such as ``model.ckpt-1234`` this is presumably the global step -- confirm
    against the checkpoint naming scheme).

    If that export directory already exists, the checkpoint is treated as
    already exported: an info message is logged and the method returns.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and a SyntaxWarning from Python 3.12 on).
    checkpoint_number = re.findall(r'\d+', self._latest_checkpoint)[-1]
    export_dir = 'export_ckpt_' + checkpoint_number
    tf.logging.info('Exporting model from checkpoint {0}'.format(self._latest_checkpoint))
    prediction_graph = tf.Graph()
    try:
        exporter = tf.saved_model.builder.SavedModelBuilder(
            os.path.join(self._checkpoint_dir, export_dir))
    except (IOError, AssertionError):
        # SavedModelBuilder reports an already-existing export directory with
        # an AssertionError rather than an IOError; catching only IOError meant
        # this "already exported" path never ran and the hook crashed instead.
        tf.logging.info('Checkpoint {0} already exported, continuing...'.format(self._latest_checkpoint))
        return
    with prediction_graph.as_default():
        image, name, inputs_dict = model.serving_input_fn()
        # The literal 6 fills the argument slot that the training code passes
        # `num_channels = 6` into; presumably the number of input channels --
        # confirm it stays in sync with training.
        prediction_dict = model.model_fn(model.PREDICT, name, image, None, 6, None)
        saver = tf.train.Saver()
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }
        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=sig_constants.PREDICT_METHOD_NAME
        )
    with tf.Session(graph=prediction_graph) as session:
        saver.restore(session, self._latest_checkpoint)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def},
            # my_main_op is defined elsewhere; presumably it builds the init op
            # run when the SavedModel is loaded -- confirm.
            legacy_init_op=my_main_op()
        )
    exporter.save()
# Python model_fn() example source code
def run(target, is_chief, train_steps, job_dir, train_files, eval_files, num_epochs, learning_rate):
    """Build the training graph and run it inside a MonitoredTrainingSession.

    Every worker gets a ``tf.train.StopAtStepHook``, the logging hook returned
    by ``model.model_fn`` and a per-step ``tf.train.SummarySaverHook``. The
    chief additionally builds a separate evaluation graph and registers an
    ``EvaluationRunHook`` and a ``CheckpointExporterHook``.

    Args:
        target: Master address handed to ``MonitoredTrainingSession``.
        is_chief: Whether this worker is the chief.
        train_steps: First positional argument of ``tf.train.StopAtStepHook``.
        job_dir: Checkpoint directory; also used for summaries and exports.
        train_files: Input files for the training pipeline.
        eval_files: Input files for the evaluation pipeline (chief only).
        num_epochs: Epoch count for the training input pipeline.
        learning_rate: Forwarded to ``model.model_fn``.
    """
    channel_count = 6

    # NOTE: the original author suspected this hook does not work well in
    # distributed mode because it may only count local steps.
    session_hooks = [tf.train.StopAtStepHook(train_steps)]

    if is_chief:
        eval_graph = tf.Graph()
        with eval_graph.as_default():
            eval_image, eval_truth, eval_name = model.input_fn(
                eval_files, 1, shuffle=False, shared_name=None)
            # In EVAL mode model_fn returns a dict of tensors to evaluate.
            eval_metrics = model.model_fn(
                model.EVAL, eval_name, eval_image, eval_truth,
                channel_count, learning_rate)
        # Evaluation runs in its own graph, separate from training.
        session_hooks.append(EvaluationRunHook(job_dir, eval_metrics, eval_graph))
        session_hooks.append(CheckpointExporterHook(job_dir))

    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter()):
            # Training inputs come from a shared, shuffled filename queue.
            train_image, train_truth, train_name = model.input_fn(
                train_files, num_epochs, shuffle=True, shared_name='train_queue')
            # In TRAIN mode model_fn returns (train op, logging hook, summaries).
            train_op, log_hook, train_summaries = model.model_fn(
                model.TRAIN, train_name, train_image, train_truth,
                channel_count, learning_rate)

        session_hooks.extend([
            log_hook,
            tf.train.SummarySaverHook(
                save_steps=1,
                output_dir=get_summary_dir(job_dir),
                summary_op=train_summaries),
        ])

        # MonitoredTrainingSession handles initialization, recovery and hooks:
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        with tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                checkpoint_dir=job_dir,
                hooks=session_hooks,
                save_checkpoint_secs=60 * 3,
                save_summaries_steps=1,
                log_step_count_steps=5) as session:
            # Keep stepping until a hook (e.g. the step limit) requests a stop.
            while not session.should_stop():
                session.run(train_op)
def build_and_run_exports(latest, job_dir, serving_input_fn, hidden_units):
    """Export the model stored in checkpoint ``latest`` as a SavedModel.

    The SavedModel is written to ``<job_dir>/export`` with a single PREDICT
    signature registered under the default serving signature key.

    Args:
        latest (string): Path of the checkpoint whose variables are restored.
        job_dir (string): Location of checkpoints and model files.
        serving_input_fn (callable): Called with no arguments; must return a
            ``(features, inputs_dict)`` pair. A copy of ``features`` is passed
            to ``model.model_fn``; ``inputs_dict`` maps signature input names
            to tensors.
        hidden_units (list): Number of hidden units
    """
    prediction_graph = tf.Graph()
    exporter = tf.saved_model.builder.SavedModelBuilder(
        os.path.join(job_dir, 'export'))
    with prediction_graph.as_default():
        features, inputs_dict = serving_input_fn()
        prediction_dict = model.model_fn(
            model.PREDICT,
            features.copy(),  # copy so model_fn cannot mutate our dict
            None,  # labels
            hidden_units=hidden_units,
            learning_rate=None  # learning_rate unused in prediction mode
        )
        saver = tf.train.Saver()
        # `.items()` rather than the Python-2-only `.iteritems()`.
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }
        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=sig_constants.PREDICT_METHOD_NAME
        )
    with tf.Session(graph=prediction_graph) as session:
        session.run([tf.local_variables_initializer(), tf.tables_initializer()])
        saver.restore(session, latest)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
            },
            # NOTE(review): main_op is imported/defined elsewhere. If it is
            # TensorFlow's saved_model main_op(), it also runs the global
            # variables initializer when the model is loaded, which would
            # overwrite restored weights -- confirm which main_op this is.
            legacy_init_op=main_op()
        )
    exporter.save()
def build_and_run_exports(latest, job_dir, name, serving_input_fn, hidden_units):
    """Export the model stored in checkpoint ``latest`` as a SavedModel.

    The SavedModel is written to ``<job_dir>/export/<name>`` with a single
    PREDICT signature registered under the default serving signature key.
    Local variables and lookup tables are initialized before the restore, and
    the same initialization op is stored in the SavedModel so it is re-run
    whenever the exported model is loaded.

    Args:
        latest (string): Path of the checkpoint whose variables are restored.
        job_dir (string): Location of checkpoints and model files.
        name (string): Name of the checkpoint to be exported. Used in building
            the export path.
        serving_input_fn (callable): Called with no arguments; must return a
            ``(features, inputs_dict)`` pair. ``features`` is passed to
            ``model.model_fn``; ``inputs_dict`` maps signature input names to
            tensors.
        hidden_units (list): Number of hidden units
    """
    prediction_graph = tf.Graph()
    exporter = tf.saved_model.builder.SavedModelBuilder(
        os.path.join(job_dir, 'export', name))
    with prediction_graph.as_default():
        features, inputs_dict = serving_input_fn()
        prediction_dict = model.model_fn(
            model.PREDICT,
            features,
            None,  # labels
            hidden_units=hidden_units,
            learning_rate=None  # learning_rate unused in prediction mode
        )
        saver = tf.train.Saver()
        # `key` (not `name`) avoids visually shadowing the `name` parameter;
        # `.items()` replaces the Python-2-only `.iteritems()`.
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }
        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        )
        # Table contents are not stored as variables in a checkpoint, so the
        # loader must re-initialize them. Without an init op in the SavedModel,
        # lookup tables would be uninitialized when the model is served.
        init_op = tf.group(tf.local_variables_initializer(),
                           tf.tables_initializer())
    with tf.Session(graph=prediction_graph) as session:
        session.run(init_op)
        saver.restore(session, latest)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
            },
            legacy_init_op=init_op,
        )
    exporter.save()