def __init__(self, input_args, dest_nodes=None):
    """Load a TensorFlow model and build the parser's graph from it.

    Args:
        input_args: Either a string, which is passed to ``_load_meta``, or a
            tuple whose first item is passed to ``_load_meta`` and whose
            second item is passed to ``_load_weights``.
        dest_nodes: Optional string of comma-separated node names. When
            given, the loaded model is passed through TensorFlow's
            ``extract_sub_graph`` with those names before the graph is built.

    Raises:
        TypeError: If ``input_args`` is neither a string nor a tuple.
    """
    super(TensorflowParser, self).__init__()

    # six.string_types covers both str and unicode on Python 2.
    from six import string_types as _string_types

    # Load the model definition (plus the weights when a tuple is supplied).
    if isinstance(input_args, _string_types):
        model = TensorflowParser._load_meta(input_args)
    elif isinstance(input_args, tuple):
        model = TensorflowParser._load_meta(input_args[0])
        self.ckpt_data = TensorflowParser._load_weights(input_args[1])
        self.weight_loaded = True
    else:
        # Without this branch `model` would be unbound and the failure would
        # surface later as an obscure UnboundLocalError.
        raise TypeError(
            "input_args must be a string or a tuple, got %s"
            % type(input_args).__name__)

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
# Example source code showing usages of Python's extract_sub_graph()
def save_graph_only(sess, output_file_path, output_node_names, as_text=False):
    """Save a small version of the graph based on a session and the output node names.

    The device field of every node is cleared, the graph is reduced with
    ``graph_util.extract_sub_graph`` to ``output_node_names``, and the result
    is written with ``graph_io.write_graph``.

    Args:
        sess: Object exposing a ``graph_def`` attribute (presumably a
            TensorFlow session -- confirm against callers).
        output_file_path: Destination path; it is split into directory and
            file name for ``graph_io.write_graph``.
        output_node_names: Node names forwarded unchanged to
            ``graph_util.extract_sub_graph``.
        as_text: Forwarded to ``graph_io.write_graph``.
    """
    # Fetch the GraphDef exactly once. In TensorFlow 1.x `Session.graph_def`
    # builds a new proto on every access, so clearing devices on one access
    # and extracting from another would silently drop the device edits.
    full_graph_def = sess.graph_def
    for node in full_graph_def.node:
        node.device = ''

    graph_def = graph_util.extract_sub_graph(full_graph_def, output_node_names)
    output_dir, output_filename = os.path.split(output_file_path)
    graph_io.write_graph(graph_def, output_dir, output_filename, as_text=as_text)
def remove_dead_nodes(self, output_names):
    """Drop nodes from the graph that inference no longer needs.

    Replaces ``self.output_graph`` with the result of calling
    ``graph_util.extract_sub_graph`` on it with ``output_names``.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)
def remove_dead_nodes(self, output_names):
    """Prune the output graph to the nodes still needed for inference.

    ``self.output_graph`` is overwritten with whatever
    ``graph_util.extract_sub_graph`` returns for it and ``output_names``.
    """
    pruned_graph = graph_util.extract_sub_graph(self.output_graph, output_names)
    self.output_graph = pruned_graph
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Removes unused nodes from a graph.

    Reads a GraphDef from `input_graph`, replaces every node named in
    `input_node_names` with a Placeholder of the same name, reduces the
    result with `graph_util.extract_sub_graph` to `output_node_names`, and
    writes the serialized sub-graph to `output_graph`.

    Args:
        input_graph: Path of the GraphDef file to read.
        input_binary: If true the file is parsed as a binary proto,
            otherwise as a text-format proto.
        output_graph: Path the serialized result is written to.
        input_node_names: Comma-separated names of the nodes to replace
            with placeholders.
        output_node_names: Comma-separated names of the nodes to keep
            reachable.
        placeholder_type_enum: Value stored as the `dtype` attribute of each
            generated placeholder.

    Returns:
        -1 if `input_graph` does not exist or `output_node_names` is empty;
        otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    # GFile is used for reading as well, matching the write below.
    with tf.gfile.GFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    # A set keeps the per-node membership test constant-time.
    input_node_name_set = set(input_node_names.split(","))
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        # add() appends a fresh NodeDef owned by the new graph, so each
        # source node is copied exactly once and never aliased.
        new_node = inputs_replaced_graph_def.node.add()
        if node.name in input_node_name_set:
            new_node.op = "Placeholder"
            new_node.name = node.name
            new_node.attr["dtype"].CopyFrom(
                tf.AttrValue(type=placeholder_type_enum))
        else:
            new_node.CopyFrom(node)

    output_graph_def = graph_util.extract_sub_graph(
        inputs_replaced_graph_def, output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))