def load_model(sess, model_path, input_map=None):
    """Load a TensorFlow model from a frozen-graph file or a checkpoint directory.

    ``model_path`` may be either:

    * a regular file -- treated as a serialized ``GraphDef`` protobuf holding a
      frozen graph; it is parsed and imported with ``tf.import_graph_def``
      under an empty name scope; or
    * anything else (expected to be a directory) -- the metagraph and
      checkpoint file names are obtained from ``get_model_filenames``, the
      metagraph is imported with ``clear_devices=True`` and the checkpoint is
      restored into ``sess``.

    Args:
        sess: TensorFlow session the checkpoint variables are restored into.
            Only used in the directory branch.
        model_path: Path to a frozen-graph file or to a model directory. A
            leading ``~`` is expanded to the user's home directory.
        input_map: Optional mapping from tensor names in the loaded graph to
            existing tensors, forwarded to ``tf.import_graph_def`` /
            ``tf.train.import_meta_graph``. Defaults to ``None`` (no
            remapping), which is the previous behaviour.
    """
    # Expand "~" so e.g. "~/models/frozen.pb" is recognised as a file instead
    # of silently falling through to the directory branch below.
    model_path = os.path.expanduser(model_path)

    # NOTE(review): both branches import into the *default* graph, not
    # explicitly into sess.graph; presumably callers make sess.graph the
    # default graph before calling -- confirm.
    if os.path.isfile(model_path):
        # A protobuf file with a frozen graph
        print('Model filename: %s' % model_path)
        with gfile.FastGFile(model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, input_map=input_map, name='')
    else:
        # A directory containing a metagraph file and a checkpoint file
        print('Model directory: %s' % model_path)
        # get_model_filenames returns names relative to model_path
        # (they are joined with model_path below).
        meta_file, ckpt_file = get_model_filenames(model_path)
        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)
        saver = tf.train.import_meta_graph(
            os.path.join(model_path, meta_file),
            clear_devices=True,
            input_map=input_map,
        )
        saver.restore(sess, os.path.join(model_path, ckpt_file))
评论列表
文章目录