def load_frozen_graph(graph_dir, fix_nodes=True, entry=None, output=None):
    """Load a serialized ``GraphDef`` and import it into the default graph.

    Args:
        graph_dir: Path handed directly to ``gfile.FastGFile`` and read as
            bytes. Despite the name, it must point at the serialized graph
            file itself, not at a directory containing it.
        fix_nodes: When true, patch the parsed graph before importing it:
            * ``RefSwitch`` nodes become ``Switch`` nodes, and each of their
              inputs whose name contains ``'moving_'`` gets ``'/read'``
              appended;
            * ``AssignSub`` nodes become ``Sub`` nodes and lose their
              ``use_locking`` attribute, if present.
            NOTE(review): this looks like the usual workaround for importing
            graphs frozen with batch-norm update ops still in them --
            confirm against the models this is used with.
        entry: Optional tensor name to look up in the default graph after
            the import (as accepted by ``Graph.get_tensor_by_name``).
        output: Optional tensor name, resolved the same way as ``entry``.

    Returns:
        A ``(entry, output)`` tuple. Each element is the tensor found under
        the given name, or ``None`` when the matching argument was ``None``.
    """
    # Only the read needs the file handle; release it before the graph work.
    with gfile.FastGFile(graph_dir, "rb") as graph_file:
        serialized_graph = graph_file.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized_graph)

    if fix_nodes:
        for node in graph_def.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                # Assign by index: the input list is edited in place and
                # never changes length while it is being walked.
                for index, input_name in enumerate(node.input):
                    if 'moving_' in input_name:
                        node.input[index] = input_name + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']

    # name="" imports the nodes without adding a name prefix, so the names
    # passed as `entry` / `output` are looked up exactly as given.
    tf.import_graph_def(graph_def, name="")

    graph = tf.get_default_graph()
    if entry is not None:
        entry = graph.get_tensor_by_name(entry)
    if output is not None:
        output = graph.get_tensor_by_name(output)
    return entry, output
评论列表
文章目录