def export_graph(input_path, output_path, output_nodes, debug=False):
    """Freeze the latest checkpoint found in ``input_path`` into a GraphDef file.

    The meta graph belonging to the most recent checkpoint is imported into
    the default graph, a few stateful ops are rewritten to stateless
    equivalents (see the linked TensorFlow issue below), the variables are
    restored and converted to constants, and the resulting graph definition
    is written in binary form to ``output_path``.

    Args:
        input_path: Directory that is searched for the latest checkpoint.
        output_path: Target file; its directory part and base name are passed
            separately to ``tf.train.write_graph``.
        output_nodes: Names of the output nodes handed to
            ``graph_util.convert_variables_to_constants``.
        debug: When true, print the graph's operation names and a progress
            message.

    Raises:
        IOError: If no checkpoint can be found in ``input_path``.
    """
    # todo: might want to look at http://stackoverflow.com/a/39578062/195651
    checkpoint = tf.train.latest_checkpoint(input_path)
    if checkpoint is None:
        # latest_checkpoint() yields None when nothing is found; without this
        # guard the string concatenation below fails with an opaque TypeError.
        raise IOError('No checkpoint found in %r' % (input_path,))

    importer = tf.train.import_meta_graph(checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # type: tf.Graph
    gd = graph.as_graph_def()  # type: tf.GraphDef

    if debug:
        op_names = [op.name for op in graph.get_operations()]
        print(op_names)

    # fix batch norm nodes
    # https://github.com/tensorflow/tensorflow/issues/3628
    # Assign* ops are swapped for their plain arithmetic counterparts; the
    # 'use_locking' attribute is dropped from the rewritten node.
    assign_replacements = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in gd.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Point inputs that reference moving_* variables at their
            # '/read' node instead.
            for index, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op in assign_replacements:
            node.op = assign_replacements[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    if debug:
        print('Freezing the graph ...')

    with tf.Session() as sess:
        importer.restore(sess, checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(sess, gd, output_nodes)
        tf.train.write_graph(
            output_graph_def,
            path.dirname(output_path),
            path.basename(output_path),
            as_text=False,
        )
# (Web-page residue, not part of the code: "Comment list" / "Article table of contents")