def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Return a copy of the graph with whitelisted variables turned into constants.

    Before conversion, ``input_graph_def`` is patched IN PLACE:

    * ``RefSwitch`` nodes become ``Switch`` nodes, and each of their inputs
      whose name contains ``'moving_'`` gets ``'/read'`` appended.
    * ``AssignSub`` / ``AssignAdd`` nodes become ``Sub`` / ``Add`` nodes and
      lose their ``use_locking`` attribute if they have one.

    NOTE(review): this patching presumably works around batch-norm
    moving-average nodes that cannot be used once their variables are
    frozen -- confirm against the model this is applied to.

    Args:
        sess: Session passed through to
            ``graph_util.convert_variables_to_constants``.
        input_graph_def: Graph definition exposing a ``node`` sequence whose
            items have ``op``, ``name``, ``input`` and ``attr``. It is
            modified by this call.
        output_node_names: Comma-separated string of output node names.

    Returns:
        Whatever ``graph_util.convert_variables_to_constants`` returns for the
        patched graph (assumed to be TensorFlow's ``graph_util`` module; its
        import is not part of this block).
    """
    # Ops that update a variable in place, mapped to their plain counterparts.
    plain_op_for = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}

    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # enumerate() instead of xrange(len(...)): same traversal, and it
            # runs on both Python 2 and Python 3.
            for index, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op in plain_op_for:
            node.op = plain_op_for[node.op]
            # 'use_locking' is dropped along with the in-place assign op.
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables under these name prefixes are converted to constants.
    whitelist_prefixes = (
        'InceptionResnetV1',
        'embeddings',
        'phase_train',
        'Bottleneck',
        'Logits',
    )
    whitelist_names = [
        node.name
        for node in input_graph_def.node
        if node.name.startswith(whitelist_prefixes)
    ]

    # Replace the whitelisted variables in the graph with constants holding
    # their current values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
# Comment list
# Table of contents