def export(self, last_checkpoint, output_dir):
    """Build the prediction graph, load trained weights, and write the model.

    A session is opened on a brand-new ``tf.Graph()``; inside it the
    prediction graph is built with ``self.build_prediction_graph()``, the
    variables are initialized and then restored from ``last_checkpoint``.
    Two artifacts are written under ``output_dir``: the meta graph as
    ``export.meta`` and the variable values under the ``export`` prefix.

    Args:
      last_checkpoint: The latest checkpoint from training.
      output_dir: Path to the folder to be used to output the model.
    """
    logging.info('Exporting prediction graph to %s', output_dir)

    with tf.Session(graph=tf.Graph()) as session:
        # Construct the prediction graph before any variables are touched.
        self.build_prediction_graph()

        # Older TensorFlow releases do not provide
        # global_variables_initializer, so fall back to
        # initialize_all_variables there. The fallback can be removed once
        # TensorFlow 0.12 is standard.
        try:
            initializer = tf.global_variables_initializer()
        except AttributeError:
            initializer = tf.initialize_all_variables()
        session.run(initializer)

        # Load the trained variable values; this saver is used only for the
        # restore.
        tf.train.Saver().restore(session, last_checkpoint)

        # A second saver writes the export artifacts. The meta graph is
        # written explicitly, so the checkpoint save is told not to write
        # another one.
        exporter = tf.train.Saver()
        meta_graph_path = os.path.join(output_dir, 'export.meta')
        checkpoint_prefix = os.path.join(output_dir, 'export')
        exporter.export_meta_graph(filename=meta_graph_path)
        exporter.save(session, checkpoint_prefix, write_meta_graph=False)
# Comment list
# Table of contents