def freeze(model_scope, model_dir, model_file, *, verbose=True):
    """Freeze a TensorFlow 1.x checkpoint into a single ``.pb`` GraphDef file.

    Imports the meta graph ``<model_dir>/<model_file>.meta`` into the current
    default graph and restores the checkpoint ``<model_dir>/<model_file>``.
    Variables feeding the selected output nodes are then converted to
    constants, and the serialized GraphDef is written to
    ``<model_dir>/<model_file>.pb``.

    Args:
        model_scope: The prefix of all variables in the model. The node
            ``<model_scope>/readout/logits`` is kept as an output node.
        model_dir: Path to the directory that contains the checkpoint files.
        model_file: Checkpoint prefix (the file that saves the training
            results), without file suffix / extension.
        verbose: If True (the default, matching the original behavior), print
            the name of every operation in the imported graph.

    Returns:
        The path of the written ``.pb`` file.
    """
    checkpoint_prefix = os.path.join(model_dir, model_file)
    saver = tf.train.import_meta_graph(checkpoint_prefix + ".meta")
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        # import_meta_graph returns None when the meta graph contains no
        # variables; in that case there is nothing to restore.
        if saver is not None:
            saver.restore(sess, checkpoint_prefix)
        if verbose:
            print("# All operations:")
            for op in graph.get_operations():
                print(op.name)
        # Output nodes: every trainable variable (name with the ":<index>"
        # tensor suffix stripped, so it names the op) plus the logits node.
        output_node_names = [v.name.split(":")[0] for v in tf.trainable_variables()]
        output_node_names.append("{}/readout/logits".format(model_scope))
        # Needs the live session to read the restored variable values.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names,
        )
    output_file = os.path.join(model_dir, model_file + ".pb")
    with tf.gfile.GFile(output_file, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("Freezed model was saved as {}.pb.".format(model_file))
    return output_file
评论列表
文章目录