def freeze_graph(
    model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """Freeze the latest checkpoint in ``model_dir`` into a single GraphDef file.

    Restores the most recent checkpoint recorded in ``model_dir``, converts the
    variables needed by ``output_nodes_list`` into constants holding their
    restored values, and serializes the resulting GraphDef next to the
    checkpoint. Follows
    https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    Args:
        model_dir: directory containing a TensorFlow checkpoint state file and
            the matching ``.meta`` / variable data files.
        output_nodes_list: names of the output nodes to keep, e.g.
            ``['softmax_linear/logits']``. A single node-name string is also
            accepted and treated as a one-element list.
        output_graph_name: file name for the frozen graph; it is written into
            the directory that holds the checkpoint.

    Returns:
        The path of the written frozen graph, or ``None`` if no checkpoint
        could be found in ``model_dir``.
    """
    import os

    from tensorflow.python.framework import graph_util

    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))

    # convert_variables_to_constants requires a list of node names; accept a
    # bare string for convenience instead of failing deep inside TensorFlow.
    if isinstance(output_nodes_list, str):
        output_nodes_list = [output_nodes_list]

    # get_checkpoint_state returns None when model_dir holds no checkpoint
    # state, so test that before dereferencing model_checkpoint_path.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None
    input_checkpoint = checkpoint.model_checkpoint_path

    # Place the frozen graph beside the checkpoint. os.path copes with a bare
    # file name (no directory part) and platform-specific separators, where a
    # manual '/' split would yield a path at the filesystem root.
    output_graph = os.path.join(
        os.path.dirname(input_checkpoint), output_graph_name
    )

    # Import into a fresh graph so ops already living in the default graph
    # (e.g. from an earlier call) cannot collide with or leak into the
    # frozen model, and imported node names keep their original spelling.
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices=True
        )
    input_graph_def = graph.as_graph_def()

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, input_checkpoint)
        # Replace variables with constants holding their restored values and
        # prune everything not needed to compute the requested outputs.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )

    with tf.gfile.GFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
    LOGGER.info('Froze graph with {} ops'.format(
        len(output_graph_def.node)
    ))
    return output_graph
评论列表
文章目录