def _split_node_names(names):
    """Splits a comma-separated string into trimmed, non-empty names."""
    if not names:
        return []
    stripped = (name.strip() for name in names.split(","))
    return [name for name in stripped if name]


def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Removes unused nodes from a graph.

    Reads the GraphDef stored at `input_graph`, replaces every node named in
    `input_node_names` with a Placeholder op, reduces the result with
    `graph_util.extract_sub_graph()` using `output_node_names` as the
    destination nodes, and writes the serialized sub-graph to `output_graph`.

    Args:
      input_graph: Path of the GraphDef file to read.
      input_binary: If true, the file is parsed as a binary serialized proto;
        otherwise it is parsed as protobuf text format.
      output_graph: Path the binary-serialized result is written to.
      input_node_names: Comma-separated names of the nodes to replace with
        placeholders. Whitespace around names and empty entries are ignored.
      output_node_names: Comma-separated names handed to
        `extract_sub_graph()`. Whitespace around names and empty entries are
        ignored; at least one name must remain.
      placeholder_type_enum: Value stored as the `type` of the "dtype" attr
        on every placeholder that is created.

    Returns:
      -1 if no output node name was supplied or `input_graph` does not exist;
      None after the output graph has been written.
    """
    # Validated first because it needs no filesystem access.
    output_names = _split_node_names(output_node_names)
    if not output_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '%s' does not exist!" % input_graph)
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    placeholder_names = set(_split_node_names(input_node_names))
    # Tracks requested inputs that never show up in the graph so that a typo
    # in --input_node_names does not pass unnoticed.
    missing_names = set(placeholder_names)
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in placeholder_names:
            missing_names.discard(node.name)
            placeholder_node = tf.NodeDef()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            placeholder_node.attr["dtype"].CopyFrom(
                tf.AttrValue(type=placeholder_type_enum))
            inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
    if missing_names:
        print("Warning: input nodes not found in the graph: %s"
              % ", ".join(sorted(missing_names)))

    output_graph_def = graph_util.extract_sub_graph(
        inputs_replaced_graph_def, output_names)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
评论列表
文章目录