def build_and_run_exports(latest, job_dir, name, serving_input_fn, hidden_units):
    """Export a checkpoint as a SavedModel for serving.

    Builds a fresh prediction graph, restores the weights stored in
    ``latest`` into it, and writes a SavedModel to
    ``<job_dir>/export/<name>`` with a single predict signature registered
    under the default serving signature key.

    Args:
        latest (string): Path of the checkpoint file to restore.
        job_dir (string): Location of checkpoints and model files; the
            SavedModel is written under its ``export`` subdirectory.
        name (string): Name of the checkpoint to be exported. Used as the
            final component of the export path.
        serving_input_fn (callable): Zero-argument function returning a
            ``(features, inputs_dict)`` pair. ``features`` is passed to
            ``model.model_fn``; ``inputs_dict`` maps signature input names
            to tensors.
        hidden_units (list): Number of hidden units, forwarded to
            ``model.model_fn``.
    """
    prediction_graph = tf.Graph()
    # NOTE(review): SavedModelBuilder is expected to refuse an export
    # directory that already exists -- confirm for the TF version in use.
    exporter = tf.saved_model.builder.SavedModelBuilder(
        os.path.join(job_dir, 'export', name))

    with prediction_graph.as_default():
        features, inputs_dict = serving_input_fn()

        # Snapshot the input signature before model_fn runs. If ``features``
        # and ``inputs_dict`` are the same dict and model_fn rewrites entries
        # in place, building it afterwards would capture transformed tensors
        # instead of the raw serving inputs. ``.items()`` (rather than the
        # Python 2-only ``.iteritems()``) works on Python 2 and 3, and ``key``
        # avoids shadowing the ``name`` parameter.
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }

        prediction_dict = model.model_fn(
            model.PREDICT,
            features,
            None,  # labels
            hidden_units=hidden_units,
            learning_rate=None,  # learning_rate unused in prediction mode
        )
        saver = tf.train.Saver()

        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
        )

        # Local variables and lookup tables are initialized separately from
        # the checkpoint restore below. They presumably are not covered by the
        # checkpoint, so a process loading the SavedModel needs the same
        # initialization. Registering this op as the legacy_init_op asks the
        # loader to run it.
        init_op = tf.group(
            tf.local_variables_initializer(),
            tf.tables_initializer(),
            name='legacy_init_op')

    with tf.Session(graph=prediction_graph) as session:
        session.run(init_op)
        saver.restore(session, latest)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    signature_def,
            },
            legacy_init_op=init_op,
        )
    exporter.save()
评论列表
文章目录