def _run_export(self):
    """Export ``self._latest_checkpoint`` as a SavedModel for serving.

    The SavedModel is written to ``<self._checkpoint_dir>/export_ckpt_<N>``,
    where ``N`` is the last run of digits in ``self._latest_checkpoint``
    (presumably the global step, as in ``model.ckpt-1234`` -- confirm against
    the checkpoint naming used by the trainer). If that export directory
    already exists, the checkpoint is treated as already exported and the
    method returns without doing anything.

    The exported graph is built from ``model.serving_input_fn()`` and
    ``model.model_fn(model.PREDICT, ...)``; it gets a single signature under
    the default serving key, and its variables are restored from the
    checkpoint.

    Raises:
        IndexError: if ``self._latest_checkpoint`` contains no digits.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on Python 3.12+).
    checkpoint_id = re.findall(r'\d+', self._latest_checkpoint)[-1]
    export_dir = 'export_ckpt_' + checkpoint_id
    tf.logging.info(
        'Exporting model from checkpoint {0}'.format(self._latest_checkpoint))

    try:
        exporter = tf.saved_model.builder.SavedModelBuilder(
            os.path.join(self._checkpoint_dir, export_dir))
    except (IOError, AssertionError):
        # SavedModelBuilder raises AssertionError (not IOError) when the
        # export directory already exists; IOError is kept for filesystem
        # failures that the original code already tolerated.
        tf.logging.info(
            'Checkpoint {0} already exported, continuing...'.format(
                self._latest_checkpoint))
        return

    # Build the serving graph in its own tf.Graph so it does not touch the
    # default graph.
    prediction_graph = tf.Graph()
    with prediction_graph.as_default():
        image, name, inputs_dict = model.serving_input_fn()
        # NOTE(review): positional args to model.model_fn are defined
        # elsewhere; the meaning of the literal 6 (number of classes?) is not
        # visible here -- confirm.
        prediction_dict = model.model_fn(
            model.PREDICT, name, image, None, 6, None)

        # Created after the model is built so it covers all its variables.
        saver = tf.train.Saver()

        # .items() works on both Python 2 and 3 (.iteritems() is Py2-only).
        # 'key' avoids shadowing the serving-input 'name' local above.
        inputs_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in inputs_dict.items()
        }
        output_info = {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in prediction_dict.items()
        }
        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=sig_constants.PREDICT_METHOD_NAME,
        )

        with tf.Session(graph=prediction_graph) as session:
            saver.restore(session, self._latest_checkpoint)
            exporter.add_meta_graph_and_variables(
                session,
                tags=[tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        signature_def,
                },
                # my_main_op is defined elsewhere; per the TF1 API this op
                # runs after variables are restored when the model is loaded.
                legacy_init_op=my_main_op(),
            )
            exporter.save()
CheckpointExporterHook.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录