def __init__(self, config, graph, model_scope, model_dir, model_file,
             input_tensor_name="input_1:0", output_tensor_name="s1_output0:0"):
    """Load a frozen GraphDef into ``graph`` and fetch its input/output tensors.

    Reads the serialized (binary) ``GraphDef`` at ``model_dir/model_file``,
    imports it into ``graph`` with every node name prefixed by
    ``model_scope``, and stores the model's input and output tensors.

    To discover tensor names for a new model, iterate
    ``graph.get_operations()`` after import and print each ``op.name``, or
    use the names you assigned explicitly when building the original model.

    Args:
        config: Stored unchanged on ``self.config``; not otherwise used here.
        graph: Graph the frozen model is imported into and from which the
            input/output tensors are fetched.
        model_scope: Name-scope prefix applied to every imported node.
        model_dir: Directory containing the frozen model file.
        model_file: File name of the serialized ``GraphDef``.
        input_tensor_name: Name of the input tensor relative to
            ``model_scope``. Defaults to ``"input_1:0"`` (the value that was
            previously hard-coded).
        output_tensor_name: Name of the output tensor relative to
            ``model_scope``. Defaults to ``"s1_output0:0"`` (the value that
            was previously hard-coded).

    Errors raised while reading/parsing the file or by
    ``graph.get_tensor_by_name`` (e.g. an unknown tensor name) propagate.
    """
    self.config = config
    frozen_model = os.path.join(model_dir, model_file)
    graph_def = tf.GraphDef()
    # Only the read needs the file handle; parse/import after it is closed.
    with tf.gfile.GFile(frozen_model, "rb") as f:
        serialized = f.read()
    graph_def.ParseFromString(serialized)
    # tf.import_graph_def imports into the *default* graph. Make ``graph``
    # the default explicitly so the tensors are created in the same graph
    # we query below (a no-op if ``graph`` is already the default).
    with graph.as_default():
        # The trailing "/" asks TF to use model_scope verbatim as the
        # name-scope prefix rather than uniquifying it (e.g. "scope_1"),
        # keeping the "<model_scope>/<tensor>" lookups below predictable.
        tf.import_graph_def(graph_def, name="{}/".format(model_scope))
    self.input_tensor = graph.get_tensor_by_name(
        "{}/{}".format(model_scope, input_tensor_name))
    self.output_tensor = graph.get_tensor_by_name(
        "{}/{}".format(model_scope, output_tensor_name))
# (Scraped web-page navigation labels, not code: "评论列表" = "Comment list", "文章目录" = "Table of contents".)