def setUp(self):
    """Test fixture: export a variable-free meta graph.

    Builds ``y = w * x - 7`` where ``w`` is a variable initialised to 3.0,
    converts ``w`` into a constant, and writes the resulting meta graph
    (plus the "meta" collection) to
    ``<temp_dir>/no_vars/<constants.META_GRAPH_DEF_FILENAME>``.
    """
    self.base_path = os.path.join(tf.test.get_temp_dir(), "no_vars")
    if not os.path.exists(self.base_path):
        os.mkdir(self.base_path)
    # Create a simple graph with a variable, then convert variables to
    # constants and export the graph.
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, name="x")
        w = tf.Variable(3.0)
        # tf.sub was removed in TF 1.0; tf.subtract is the supported name.
        y = tf.subtract(w * x, 7.0, name="y")  # pylint: disable=unused-variable
        tf.add_to_collection("meta", "this is meta")
        with self.test_session(graph=g) as session:
            # initialize_all_variables is deprecated; use the replacement.
            tf.global_variables_initializer().run()
            new_graph_def = graph_util.convert_variables_to_constants(
                session, g.as_graph_def(), ["y"])
        filename = os.path.join(self.base_path, constants.META_GRAPH_DEF_FILENAME)
        tf.train.export_meta_graph(
            filename, graph_def=new_graph_def, collection_list=["meta"])
Python convert_variables_to_constants() example source code
def setUp(self):
    """Test fixture: export a meta graph whose variables were frozen.

    The graph computes ``y = w * x - 7`` with ``w`` a variable set to 3.0.
    After freezing, the meta graph and its "meta" collection are written to
    ``<temp_dir>/no_vars/<constants.META_GRAPH_DEF_FILENAME>``.
    """
    self.base_path = os.path.join(tf.test.get_temp_dir(), "no_vars")
    if not os.path.exists(self.base_path):
        os.mkdir(self.base_path)
    # Create a simple graph with a variable, then convert variables to
    # constants and export the graph.
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, name="x")
        w = tf.Variable(3.0)
        # tf.sub no longer exists from TF 1.0 on; tf.subtract replaces it.
        y = tf.subtract(w * x, 7.0, name="y")  # pylint: disable=unused-variable
        tf.add_to_collection("meta", "this is meta")
        with self.test_session(graph=g) as session:
            tf.global_variables_initializer().run()
            new_graph_def = graph_util.convert_variables_to_constants(
                session, g.as_graph_def(), ["y"])
        filename = os.path.join(self.base_path, constants.META_GRAPH_DEF_FILENAME)
        tf.train.export_meta_graph(
            filename, graph_def=new_graph_def, collection_list=["meta"])
def freeze_graph(model_folder):
    """Freeze the latest checkpoint in ``model_folder`` into a .pb file.

    The graph is rebuilt from the checkpoint's ``.meta`` file with device
    placements cleared, the weights are restored, and every variable needed
    to compute "Accuracy/predictions" is converted into a constant. The
    result is written as ``frozen_model_2.pb`` next to the checkpoint and the
    op count is printed.

    Args:
        model_folder: directory holding the TensorFlow ``checkpoint`` file.
    """
    from tensorflow.python.framework import graph_util
    # Use the parameter: reading a global ``args.model_folder`` ignored the
    # caller's argument and failed whenever ``args`` was not defined.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/frozen_model_2.pb"
    # Comma-separated list of the nodes the frozen graph must keep.
    output_node_names = "Accuracy/predictions"
    clear_devices = True
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=clear_devices)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(
    model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """
    Reduce a saved model and metadata down to a deployable file.

    Restores the most recent checkpoint recorded in ``model_dir``, converts
    its variables into constants (keeping only what ``output_nodes_list``
    needs) and writes the frozen GraphDef next to the checkpoint.

    Args:
        model_dir: directory holding the TensorFlow ``checkpoint`` file.
        output_nodes_list: names of the nodes the frozen graph must keep.
        output_graph_name: file name of the frozen graph.

    Returns:
        Path of the written frozen graph, or None if no checkpoint is found.
    """
    import os
    from tensorflow.python.framework import graph_util
    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when model_dir has no checkpoint
    # state, so guard before reading model_checkpoint_path.
    input_checkpoint = checkpoint.model_checkpoint_path if checkpoint else None
    if not input_checkpoint:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None
    # os.path copes with checkpoint paths that have no directory part, where
    # '/'.join(...) would have yielded '' and an output path under '/'.
    output_graph = os.path.join(os.path.dirname(input_checkpoint),
                                output_graph_name)
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
    LOGGER.info('Froze graph with {} ops'.format(
        len(output_graph_def.node)
    ))
    return output_graph
def freeze_graph(
    model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """
    Reduce a saved model and metadata down to a deployable file; following
    https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    Args:
        model_dir: directory holding the TensorFlow ``checkpoint`` file.
        output_nodes_list: nodes to keep, e.g. ['softmax_linear/logits'].
        output_graph_name: file name of the frozen graph, written next to
            the checkpoint.

    Returns:
        Path of the written frozen graph, or None if no checkpoint is found.
    """
    import os
    from tensorflow.python.framework import graph_util
    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state yields None without a checkpoint; check it before
    # dereferencing, so the error branch below is actually reachable.
    input_checkpoint = checkpoint.model_checkpoint_path if checkpoint else None
    if not input_checkpoint:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None
    # dirname/join handle a bare file name (no '/') without writing to '/'.
    output_graph = os.path.join(os.path.dirname(input_checkpoint),
                                output_graph_name)
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
    LOGGER.info('Froze graph with {} ops'.format(
        len(output_graph_def.node)
    ))
    return output_graph
def freeze_graph(
    model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """
    Reduce a saved model and metadata down to a deployable file; following
    https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    Args:
        model_dir: directory holding the TensorFlow ``checkpoint`` file.
        output_nodes_list: nodes to keep, e.g. ['softmax_linear/logits'].
        output_graph_name: file name for the frozen graph (placed beside the
            checkpoint).

    Returns:
        Path of the frozen graph, or None when no checkpoint can be found.
    """
    import os
    from tensorflow.python.framework import graph_util
    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # A missing checkpoint makes get_checkpoint_state return None; test for
    # that before touching model_checkpoint_path.
    input_checkpoint = checkpoint.model_checkpoint_path if checkpoint else None
    if not input_checkpoint:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None
    # Use os.path so a checkpoint path without a directory stays relative.
    output_graph = os.path.join(os.path.dirname(input_checkpoint),
                                output_graph_name)
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
    LOGGER.info('Froze graph with {} ops'.format(
        len(output_graph_def.node)
    ))
    return output_graph
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Freeze the whitelisted variables of ``input_graph_def`` into constants.

    ``input_graph_def`` is patched IN PLACE first so batch-norm update ops can
    be frozen: ``RefSwitch`` becomes ``Switch`` with inputs containing
    ``moving_`` redirected to their ``/read`` tensors, and
    ``AssignSub``/``AssignAdd`` become ``Sub``/``Add`` without the
    ``use_locking`` attribute.

    Args:
        sess: session holding the current variable values.
        input_graph_def: GraphDef to freeze (mutated by this call).
        output_node_names: comma-separated names of the nodes to keep.

    Returns:
        The frozen GraphDef.
    """
    assign_replacements = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # range (not xrange) so this also runs on Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op in assign_replacements:
            node.op = assign_replacements[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
    # Only variables whose names start with one of these prefixes are frozen.
    whitelist_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                          'Bottleneck', 'Logits')
    whitelist_names = [node.name for node in input_graph_def.node
                       if node.name.startswith(whitelist_prefixes)]
    # Replace all the variables in the graph with constants of the same values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
def save_graph_to_file(sess, graph, graph_file_name):
    """Freeze ``graph`` and write it to ``graph_file_name`` as a binary GraphDef.

    Only the nodes required to compute ``FLAGS.final_tensor_name`` are kept;
    their variables are baked in using the values held by ``sess``.
    """
    source_def = graph.as_graph_def()
    frozen_def = graph_util.convert_variables_to_constants(
        sess, source_def, [FLAGS.final_tensor_name])
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(frozen_def.SerializeToString())
def generatePB(pb_dest = "model_mnist_bnn.pb"):
    """Freeze the global ``sess`` graph at node 'output' and save it.

    Args:
        pb_dest: path of the binary GraphDef file to write.
    """
    frozen_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['output'])
    with gfile.FastGFile(pb_dest, 'wb') as pb_file:
        pb_file.write(frozen_def.SerializeToString())
    print('pb saved')
def generatePB(pb_dest = "model.pb"):
    """Serialize the global ``sess`` graph, frozen around 'output', to disk.

    Args:
        pb_dest: destination path for the binary GraphDef.
    """
    keep = ['output']
    graph_def = sess.graph.as_graph_def()
    constant_def = graph_util.convert_variables_to_constants(sess, graph_def, keep)
    with gfile.FastGFile(pb_dest, 'wb') as out:
        out.write(constant_def.SerializeToString())
    print('pb saved')
def generatePB(pb_dest = "cifar_bnn_new.pb"):
    """Write the global ``sess`` graph, with variables frozen, to ``pb_dest``.

    Only the subgraph feeding the 'output' node is kept.
    """
    current_def = sess.graph.as_graph_def()
    frozen = graph_util.convert_variables_to_constants(
        sess, current_def, ['output'])
    with gfile.FastGFile(pb_dest, 'wb') as handle:
        handle.write(frozen.SerializeToString())
    print('pb saved')
def freeze_graph(model_folder):
    """Freeze the latest checkpoint in ``model_folder`` into ``frozen_model.pb``.

    The meta graph is imported with device placements cleared, the weights
    are restored, and every variable needed to compute
    "Accuracy/predictions" is baked in as a constant. The frozen GraphDef
    is written next to the checkpoint and its op count is printed.
    """
    state = tf.train.get_checkpoint_state(model_folder)
    ckpt_path = state.model_checkpoint_path
    ckpt_dir = "/".join(ckpt_path.split('/')[:-1])
    frozen_path = ckpt_dir + "/frozen_model.pb"
    # Comma-separated, so several output nodes could be listed here.
    output_node_names = "Accuracy/predictions"
    # Dropping device pins lets TensorFlow place ops wherever it is loaded.
    restorer = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
    source_graph_def = tf.get_default_graph().as_graph_def()
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_path)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, source_graph_def, output_node_names.split(","))
        with tf.gfile.GFile(frozen_path, "wb") as out_file:
            out_file.write(frozen_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(frozen_graph_def.node))
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Convert whitelisted variables of ``input_graph_def`` into constants.

    Before freezing, batch-norm update ops in ``input_graph_def`` are
    rewritten in place: ``RefSwitch`` -> ``Switch`` (inputs containing
    ``moving_`` get a ``/read`` suffix) and ``AssignSub``/``AssignAdd`` ->
    ``Sub``/``Add`` with ``use_locking`` removed.

    Args:
        sess: session holding the variable values.
        input_graph_def: GraphDef to freeze; it is modified by this call.
        output_node_names: comma-separated names of the output nodes.

    Returns:
        The frozen GraphDef.
    """
    assign_replacements = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # xrange does not exist on Python 3; range works on both.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op in assign_replacements:
            node.op = assign_replacements[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
    # Get the list of important nodes: names starting with these prefixes.
    important_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                          'Bottleneck', 'Logits')
    whitelist_names = [node.name for node in input_graph_def.node
                       if node.name.startswith(important_prefixes)]
    # Replace all the variables in the graph with constants of the same values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
def run():
    """Freeze a variable, re-import the frozen graph, and add it back in.

    Creates variable 'y1' = [88.8, 5.0], freezes it with
    convert_variables_to_constants, and passes the serialized GraphDef to
    restore_graph. It then looks up 'restore/y1:0' (presumably the
    re-imported copy; confirm against restore_graph), adds it to y1, and
    prints the new value. Node listings are printed before and after the
    restore.
    """
    log.info('Run freeze restore')
    y = tf.Variable([88.8, 5.0], name='y1')
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        graph = sess.graph
        frozen_def = graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['y1'])
        payload = frozen_def.SerializeToString()
        print_nodes(graph.as_graph_def(), 'before restore:')
        _ = restore_graph(payload)
        print_nodes(graph.as_graph_def(), 'after restore:')
        restored = graph.get_tensor_by_name('restore/y1:0')
        sess.run(y.assign(y + restored))
        print(sess.run(y))
def _set_variable_and_publish(self, sess, iteration_id, transaction_id,
                              group_id):
    """Freeze ``self.variables``, store the blob, and announce it.

    The frozen GraphDef is stored with ``self.rc.set(transaction_id, ...)``.
    A JSON 'set_variable' message describing it is then published through
    ``self.r`` on the global ``channel``.

    Returns:
        Length of the serialized frozen GraphDef.
    """
    names = [variable.op.name for variable in self.variables]
    graph_def = sess.graph.as_graph_def()
    frozen_def = graph_util.convert_variables_to_constants(
        sess, graph_def, names)
    blob = frozen_def.SerializeToString()
    parallel_count = self.infra_info['parallel_count']
    self.rc.set(transaction_id, blob)
    payload = json.dumps({
        'key': 'set_variable',
        'transaction_id': transaction_id,
        'group_id': group_id,
        'variables': names,
        'worker_id': self.worker_id,
        'train_id': self.train_id,
        'iteration_id': iteration_id,
        'parallel_count': parallel_count
    })
    self.r.publish(channel=channel, message=payload)
    log.debug('pub %s' % iteration_id)
    return len(blob)
def _set_variable_and_publish(self, sess, iteration_id, variables,
                              transaction_id, group_id):
    """Freeze ``variables``, store them under ``transaction_id``, and publish.

    After storing the serialized frozen GraphDef in ``self.rc``, a JSON
    'set_variable' notice is published via ``self.r`` on the global
    ``channel`` and logged through ``self._log``.

    Returns:
        Size of the serialized frozen GraphDef.
    """
    names = [item.op.name for item in variables]
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), names)
    serialized = frozen.SerializeToString()
    parallel_count = self.infra_info['parallel_count']
    self.rc.set(transaction_id, serialized)
    notice = json.dumps({
        'key': 'set_variable',
        'transaction_id': transaction_id,
        'group_id': group_id,
        'variables': names,
        'worker_id': self.worker_id,
        'train_id': self.train_id,
        'iteration_id': iteration_id,
        'parallel_count': parallel_count
    })
    self.r.publish(channel=channel, message=notice)
    self._log('pub %s' % iteration_id)
    return len(serialized)
def save_graph_to_file(sess, graph, graph_file_name):
    """Persist ``graph`` with variables frozen to ``graph_file_name``.

    The output keeps only what ``FLAGS.final_tensor_name`` depends on and
    is written as a binary protobuf.
    """
    target_nodes = [FLAGS.final_tensor_name]
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), target_nodes)
    with gfile.FastGFile(graph_file_name, 'wb') as handle:
        handle.write(constant_graph_def.SerializeToString())
def freeze_graph(model_folder, net_name):
    """Freeze the latest checkpoint in ``model_folder`` into ``<net_name>.pb``.

    The restored graph's node names are printed. Then every variable needed
    by the hard-coded output nodes is converted into a constant, and the
    frozen GraphDef is written next to the checkpoint.

    Args:
        model_folder: directory holding the TensorFlow ``checkpoint`` file.
        net_name: base name (without extension) of the output .pb file.
    """
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path
    # The frozen graph is written next to the checkpoint.
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/%s.pb" % net_name
    # Nodes the frozen graph must keep (comma-separated); everything they do
    # not depend on is pruned.
    output_node_names = "Placeholder,Placeholder_1,Placeholder_2,out/add"
    # Clearing devices lets TensorFlow choose placement when the graph is loaded.
    clear_devices = True
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        graph_def = sess.graph.as_graph_def()
        for node in graph_def.node:
            # Function-call form: the bare `print node.name` statement is a
            # SyntaxError on Python 3 (output is unchanged on Python 2).
            print(node.name)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # Selects the nodes to keep
        )
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
def export_graph(input_path, output_path, output_nodes, debug=False):
    """Freeze the newest checkpoint under ``input_path`` into ``output_path``.

    Args:
        input_path: directory searched with ``tf.train.latest_checkpoint``.
        output_path: file that receives the binary frozen GraphDef.
        output_nodes: names of the nodes whose dependencies are kept.
        debug: when true, print all op names and a progress message.
    """
    # todo: might want to look at http://stackoverflow.com/a/39578062/195651
    ckpt = tf.train.latest_checkpoint(input_path)
    restorer = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
    default_graph = tf.get_default_graph()  # type: tf.Graph
    graph_def = default_graph.as_graph_def()  # type: tf.GraphDef
    if debug:
        print([op.name for op in default_graph.get_operations()])

    # Rewrite batch-norm update ops so the graph can be frozen; see
    # https://github.com/tensorflow/tensorflow/issues/3628
    op_rewrites = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for idx in range(len(node.input)):
                if 'moving_' in node.input[idx]:
                    node.input[idx] += '/read'
        elif node.op in op_rewrites:
            node.op = op_rewrites[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    if debug:
        print('Freezing the graph ...')
    with tf.Session() as sess:
        restorer.restore(sess, ckpt)
        frozen_def = graph_util.convert_variables_to_constants(
            sess, graph_def, output_nodes)
        tf.train.write_graph(frozen_def, path.dirname(output_path),
                             path.basename(output_path), as_text=False)
def freeze_graph(model_folder):
    """Freeze the latest checkpoint in ``model_folder`` around 'generate_output/output'.

    The frozen GraphDef is written as ``frozen_model.pb`` in the
    checkpoint's directory, and the number of ops it contains is printed.
    """
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    ckpt_file = ckpt_state.model_checkpoint_path
    model_dir = '/'.join(ckpt_file.split('/')[:-1])
    frozen_file = model_dir + '/frozen_model.pb'
    # Comma-separated names of the nodes to keep; anything they do not
    # depend on is dropped.
    output_node_names = 'generate_output/output'
    # Device pins are removed so placement is decided at load time.
    meta_saver = tf.train.import_meta_graph(ckpt_file + '.meta', clear_devices=True)
    graph_def = tf.get_default_graph().as_graph_def()
    with tf.Session() as sess:
        meta_saver.restore(sess, ckpt_file)
        frozen_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names.split(","),
        )
        with tf.gfile.GFile(frozen_file, 'wb') as handle:
            handle.write(frozen_def.SerializeToString())
        print('%d ops in the final graph.' % len(frozen_def.node))
def main(_):
    """Freeze the newest checkpoint in FLAGS.checkpoint_dir into FLAGS.model_dir/model.pb.

    Restores the meta graph and weights (GPU memory capped at
    FLAGS.gpu_fraction). Batch-norm update ops are patched so they can be
    frozen. Only what "output_prob" depends on is kept, and the op count is
    printed.

    Raises:
        ValueError: if FLAGS.checkpoint_dir holds no checkpoint.
    """
    output_node_names = "output_prob"
    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_fraction
    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    # latest_checkpoint returns None when nothing is found. Check before
    # building '<ckpt>.meta', which would otherwise raise a TypeError.
    if not ckpt:
        raise ValueError(
            'No checkpoint found in {}'.format(FLAGS.checkpoint_dir))
    with tf.Session(config=session_config) as sess:
        saver = tf.train.import_meta_graph(ckpt + '.meta')
        saver.restore(sess, ckpt)
        # Retrieve the protobuf graph definition and fix the batch norm nodes.
        # Ref 1 Solution: https://github.com/davidsandberg/facenet/issues/161
        # Ref 2 Official Issue: https://github.com/tensorflow/tensorflow/issues/3628
        gd = sess.graph.as_graph_def()
        for node in gd.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index in range(len(node.input)):
                    if 'moving_' in node.input[index]:
                        node.input[index] = node.input[index] + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
            elif node.op == 'AssignAdd':
                node.op = 'Add'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            gd,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # Selects the nodes to keep
        )
        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'model.pb'), "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
def convertGraph( modelPath, outdir, numoutputs, prefix, name):
    '''
    Converts a Keras HDF5 (.h5) model file to a frozen .pb file for TensorFlow.

    A human-readable copy of the (unfrozen) graph definition is also written
    to ``graph_def_for_reference.pb.ascii`` in ``outdir``.

    Args:
        modelPath (str): path to the .h5 file
        outdir (str): path to the output directory (created if missing)
        numoutputs (int): number of model outputs to export
        prefix (str): the prefix of the output aliasing; output ``i`` is
            exposed through an Identity node requested as ``<prefix>_<i>``
        name (str): file name of the frozen .pb written into ``outdir``
    Returns:
        None
    '''
    # makedirs also creates missing parent directories (os.mkdir does not).
    if not os.path.isdir(outdir):
        os.makedirs(outdir)
    K.set_learning_phase(0)
    net_model = load_model(modelPath)
    # Alias the outputs in the model - this sometimes makes them easier to
    # access in TF. `outputs` is always a list. For a single-output model,
    # `output` is a bare tensor, and `output[0]` would slice its batch axis.
    pred_node_names = []
    for i in range(numoutputs):
        alias = tf.identity(net_model.outputs[i], name=prefix + '_' + str(i))
        # Record the op's real name: TF may uniquify (e.g. '_1') or scope the
        # requested name, and freezing must use the actual node name.
        pred_node_names.append(alias.op.name)
    print('Output nodes names are: ', pred_node_names)
    sess = K.get_session()
    # Write the graph in human readable
    f = 'graph_def_for_reference.pb.ascii'
    tf.train.write_graph(sess.graph.as_graph_def(), outdir, f, as_text=True)
    print('Saved the graph definition in ascii format at: ', osp.join(outdir, f))
    # Write the graph in binary .pb file
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_node_names)
    graph_io.write_graph(constant_graph, outdir, name, as_text=False)
    print('Saved the constant graph (ready for inference) at: ', osp.join(outdir, name))
def _one_hot_label(label):
    """Return the one-sample target: [[1, 0]] for label 0, otherwise [[0, 1]]."""
    return [[1, 0]] if label == 0 else [[0, 1]]


def _single_sample_feed(graph, image, label):
    """Build the feed_dict for one image reshaped to (1, 224, 224, 3) and its label."""
    return {
        graph['x']: np.reshape(image, (1, 224, 224, 3)),
        graph['y']: _one_hot_label(label),
    }


def _evaluate_samples(sess, graph, images, labels, count, batch_size):
    """Accumulate correct predictions and cost over the first ``count`` samples.

    Both metrics are fetched in one ``sess.run`` per sample, so the forward
    pass is computed once rather than twice.

    Returns:
        (batches, correct, total_cost). ``total_cost`` sums each run's mean
        cost multiplied by ``batch_size``.
    """
    correct = 0
    total_cost = 0.
    for i in range(count):
        correct_in_batch, mean_cost = sess.run(
            [graph['correct_times_in_batch'], graph['cost']],
            feed_dict=_single_sample_feed(graph, images[i], labels[i]))
        correct += correct_in_batch
        total_cost += mean_cost * batch_size
    return count, correct, total_cost


def train_network(graph, batch_size, num_epochs, pb_file_path):
    """Train on the global x_train/y_train, report metrics, and freeze the model.

    Each epoch runs ``graph['optimize']`` once per training sample, one
    image per step. Every ``epoch_delta`` epochs, accuracy and cost are
    printed for those training samples and for the validation samples in
    x_val/y_val. Finally, the graph is frozen around the "output" node and
    written to ``pb_file_path``.

    NOTE(review): samples are fed one at a time, but totals and accuracy
    denominators are scaled by ``batch_size``. The reported numbers are exact
    only when batch_size == 1.
    """
    train_count = 12  # samples of x_train / y_train used each epoch
    val_count = 3  # samples of x_val / y_val used for evaluation
    epoch_delta = 2
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for epoch_index in range(num_epochs):
            for i in range(train_count):
                sess.run([graph['optimize']],
                         feed_dict=_single_sample_feed(graph, x_train[i], y_train[i]))
            if epoch_index % epoch_delta == 0:
                train_batches, train_correct, train_cost = _evaluate_samples(
                    sess, graph, x_train, y_train, train_count, batch_size)
                test_batches, test_correct, test_cost = _evaluate_samples(
                    sess, graph, x_val, y_val, val_count, batch_size)
                acy_on_test = test_correct / float(test_batches * batch_size)
                acy_on_train = train_correct / float(train_batches * batch_size)
                print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.format(
                    epoch_index,
                    acy_on_test * 100.0,
                    test_correct,
                    test_batches * batch_size,
                    test_cost,
                    acy_on_train * 100.0,
                    train_correct,
                    train_batches * batch_size,
                    train_cost))
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["output"])
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: path of the GraphDef file to freeze.
        input_saver: optional SaverDef file. When empty, variables are
            restored by running ``restore_op_name`` instead.
        input_binary: True for binary protos, False for text protos.
        input_checkpoint: checkpoint path or (Saver V2) checkpoint prefix.
        output_node_names: comma-separated names of the nodes to keep.
        restore_op_name: op run to restore variables when no saver is given.
        filename_tensor_name: tensor fed ``input_checkpoint`` for that op.
        output_graph: path the frozen binary GraphDef is written to.
        clear_devices: strip device placements from every node when true.
        initializer_nodes: optional nodes run after restoring.

    Returns:
        -1 if an input is missing or invalid; otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not tf.train.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            graph_text = f.read()
            # A text-mode read may already return str (Python 3), which has
            # no .decode(); only decode when we actually received bytes.
            if isinstance(graph_text, bytes):
                graph_text = graph_text.decode("utf-8")
            text_format.Merge(graph_text, input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
        if initializer_nodes:
            sess.run(initializer_nodes)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def _calculate_average_and_put(self, group_id, item, m):
    """Average each variable across the graphs listed in ``item`` and store it.

    For each name ``v`` in ``item['variables']``:
      - the blob stored in ``self.rc`` under each key in ``item['keys']`` is
        re-imported with ``util.restore_graph``;
      - the tensors ``'<key>/<v>:0'`` are averaged;
      - the mean is kept in a new variable ``'average_<v>'``.
    All averages are then frozen into one GraphDef, which is stored in
    ``self.rc`` under ``group_id``. Timings are recorded through
    ``SimpleMeasurement`` objects attached to ``m``.
    """
    keys = item['keys']
    tf.reset_default_graph()
    sess = tf.Session()
    try:
        new_vars = []
        m_cal_and_put = SimpleMeasurement('cal_and_put', m)
        m_init = SimpleMeasurement('init', m)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        m_init.end_measure()
        for v in item['variables']:
            name = 'average_%s' % v
            ts = []
            for key in keys:
                raw = self.rc.get(key)
                # TODO: check raw is not None
                util.restore_graph(key, raw)
                ts.append(sess.graph.get_tensor_by_name('%s/%s:0' % (key, v)))
            m_cal = SimpleMeasurement('cal', m)
            avg = tf.foldl(tf.add, ts) / len(ts)
            new_var = tf.Variable(avg, name=name)
            sess.run(new_var.initializer)
            sess.run(new_var)
            new_vars.append(name)
            m_cal.end_measure()
        frozen_def = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), new_vars)
        s = frozen_def.SerializeToString()
        self.rc.set(group_id, s)
    finally:
        # Release the session even when fetching or restoring a graph fails.
        sess.close()
    m_cal_and_put.end_measure()
def freeze_graph(model_folder):
    """Freeze the latest checkpoint in ``model_folder`` around 'softmaxLayer/Softmax'.

    Freezing keeps only the nodes needed to compute the requested outputs.
    Every variable they read is replaced by a constant holding the value
    restored from the checkpoint. The result is written as
    ``frozen_model.pb`` next to the checkpoint, and its op count is printed.
    """
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    ckpt_path = ckpt_state.model_checkpoint_path
    ckpt_dir = "/".join(ckpt_path.split('/')[:-1])
    frozen_path = ckpt_dir + "/frozen_model.pb"
    # Comma-separated output nodes; only their dependencies survive freezing.
    output_node_names = "softmaxLayer/Softmax"
    # Clearing device pins lets TensorFlow place ops wherever it is loaded.
    restorer = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
    graph_def = tf.get_default_graph().as_graph_def()
    # Restore the trained weights into a session so their values can be
    # baked into the frozen graph.
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_path)
        frozen_def = graph_util.convert_variables_to_constants(
            sess, graph_def, output_node_names.split(","))
        with tf.gfile.GFile(frozen_path, "wb") as out_file:
            out_file.write(frozen_def.SerializeToString())
        print("%d ops in the final graph." % len(frozen_def.node))
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes, verbose=True):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: path of the GraphDef file to freeze.
        input_saver: optional SaverDef file. When empty, variables are
            restored by running ``restore_op_name``.
        input_binary: True for binary protos, False for text protos.
        input_checkpoint: checkpoint path or (Saver V2) checkpoint prefix.
        output_node_names: comma-separated names of the nodes to keep.
        restore_op_name: op run to restore variables when no saver is given.
        filename_tensor_name: tensor fed ``input_checkpoint`` for that op.
        output_graph: path the frozen binary GraphDef is written to.
        clear_devices: strip device placements from every node when true.
        initializer_nodes: optional nodes run after restoring.
        verbose: print the op count of the frozen graph when true.

    Returns:
        -1 if an input is missing or invalid; otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    # A Saver V2 checkpoint is a prefix (files '<prefix>.index',
    # '<prefix>.data-*'), which a plain Glob on the prefix does not match.
    if not tf.train.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
        if initializer_nodes:
            sess.run(initializer_nodes)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    if verbose:
        print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: path of the GraphDef file to freeze.
        input_saver: optional SaverDef file. Without one, variables are
            restored by running ``restore_op_name``.
        input_binary: True for binary protos, False for text protos.
        input_checkpoint: checkpoint path or (Saver V2) checkpoint prefix.
        output_node_names: comma-separated names of the nodes to keep.
        restore_op_name: op run to restore variables when no saver is given.
        filename_tensor_name: tensor fed ``input_checkpoint`` for that op.
        output_graph: path the frozen binary GraphDef is written to.
        clear_devices: strip device placements from every node when true.
        initializer_nodes: optional nodes run after restoring.

    Returns:
        -1 if an input is missing or invalid; otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    # checkpoint_exists understands Saver V2 prefixes, which are not files
    # themselves and are therefore missed by tf.gfile.Glob.
    if not tf.train.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
        if initializer_nodes:
            sess.run(initializer_nodes)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    Loads a GraphDef from ``input_graph``, restores variable values from
    ``input_checkpoint`` into a session, replaces the variables reachable
    from ``output_node_names`` with constants via
    ``graph_util.convert_variables_to_constants``, and writes the resulting
    GraphDef in binary form to ``output_graph``.

    Args:
      input_graph: Path of the GraphDef file to freeze.
      input_saver: Optional path of a SaverDef file. When falsy, the
        checkpoint is restored by running ``restore_op_name`` instead.
      input_binary: If true, ``input_graph`` and ``input_saver`` are parsed
        as binary protobufs; otherwise as text protobufs.
      input_checkpoint: Checkpoint path; for Saver V2 checkpoints this is a
        prefix rather than an existing file.
      output_node_names: Comma-separated names of the nodes to keep.
      restore_op_name: Restore op run when no ``input_saver`` is given.
      filename_tensor_name: Tensor fed with ``input_checkpoint`` when running
        ``restore_op_name``.
      output_graph: Path the frozen binary GraphDef is written to.
      clear_devices: If true, clear the device field of every node.
      initializer_nodes: Optional comma-separated string of node names (or
        any other value ``Session.run`` accepts) run after restoring.

    Returns:
      -1 when an input is missing or invalid; otherwise None.
    """

    def _read_text(f):
        # Depending on the TensorFlow version, reading a gfile opened in "r"
        # mode yields bytes or str; unconditionally calling .decode() breaks
        # on str, and text_format.Merge needs text.
        data = f.read()
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        return data

    def _checkpoint_exists(path):
        # A Saver V2 checkpoint path is a prefix (no file has that exact
        # name), so a plain glob rejects valid checkpoints. Prefer
        # tf.train.checkpoint_exists when this TF version provides it.
        exists_fn = getattr(tf.train, "checkpoint_exists", None)
        if exists_fn is not None:
            return exists_fn(path)
        return bool(tf.gfile.Glob(path))

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    if not _checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    # Tolerate spaces and stray commas, e.g. "a, b," -> ["a", "b"].
    output_nodes = [name.strip()
                    for name in (output_node_names or "").split(",")
                    if name.strip()]
    if not output_nodes:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(_read_text(f), input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(_read_text(f), saver_def)
            saver = tf.train.Saver(saver_def=saver_def)
            saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})

        init_targets = initializer_nodes
        if isinstance(init_targets, str):
            # Treat a string as a comma-separated list of node names; passed
            # whole, Session.run would look up "a,b" as a single name.
            init_targets = [name.strip() for name in init_targets.split(",")
                            if name.strip()]
        if init_targets:
            sess.run(init_targets)

        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    Loads a GraphDef from ``input_graph``, restores variable values from
    ``input_checkpoint`` into a session, replaces the variables reachable
    from ``output_node_names`` with constants via
    ``graph_util.convert_variables_to_constants``, and writes the resulting
    GraphDef in binary form to ``output_graph``.

    Args:
      input_graph: Path of the GraphDef file to freeze.
      input_saver: Optional path of a SaverDef file. When falsy, the
        checkpoint is restored by running ``restore_op_name`` instead.
      input_binary: If true, ``input_graph`` and ``input_saver`` are parsed
        as binary protobufs; otherwise as text protobufs.
      input_checkpoint: Checkpoint path; for Saver V2 checkpoints this is a
        prefix rather than an existing file.
      output_node_names: Comma-separated names of the nodes to keep.
      restore_op_name: Restore op run when no ``input_saver`` is given.
      filename_tensor_name: Tensor fed with ``input_checkpoint`` when running
        ``restore_op_name``.
      output_graph: Path the frozen binary GraphDef is written to.
      clear_devices: If true, clear the device field of every node.
      initializer_nodes: Optional comma-separated string of node names (or
        any other value ``Session.run`` accepts) run after restoring.

    Returns:
      -1 when an input is missing or invalid; otherwise None.
    """

    def _read_text(f):
        # Depending on the TensorFlow version, reading a gfile opened in "r"
        # mode yields bytes or str; text_format.Merge needs text.
        data = f.read()
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        return data

    def _checkpoint_exists(path):
        # A Saver V2 checkpoint path is a prefix (no file has that exact
        # name), so a plain glob rejects valid checkpoints. Prefer
        # tf.train.checkpoint_exists when this TF version provides it.
        exists_fn = getattr(tf.train, "checkpoint_exists", None)
        if exists_fn is not None:
            return exists_fn(path)
        return bool(tf.gfile.Glob(path))

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    if not _checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    # Tolerate spaces and stray commas, e.g. "a, b," -> ["a", "b"].
    output_nodes = [name.strip()
                    for name in (output_node_names or "").split(",")
                    if name.strip()]
    if not output_nodes:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(_read_text(f), input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(_read_text(f), saver_def)
            saver = tf.train.Saver(saver_def=saver_def)
            saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})

        init_targets = initializer_nodes
        if isinstance(init_targets, str):
            # Treat a string as a comma-separated list of node names; passed
            # whole, Session.run would look up "a,b" as a single name.
            init_targets = [name.strip() for name in init_targets.split(",")
                            if name.strip()]
        if init_targets:
            sess.run(init_targets)

        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))