def load_model(ckpt_dir, variables_to_restore=None):
    """Restore model weights from a checkpoint directory.

    Assumes the model graph has already been built in the default graph.
    Prints a message and returns None (without raising) when the directory
    does not exist or contains no checkpoint files.

    Args:
        ckpt_dir: checkpoint directory, searched for "*.ckpt*" files.
        variables_to_restore: optional collection of variables to load from
            the checkpoint; when None, all saveable variables are restored.
    """
    if not os.path.exists(ckpt_dir):
        print("checkpoint dir {} not exist.".format(ckpt_dir))
        return
    ckpts = glob.glob(os.path.join(ckpt_dir, "*.ckpt*"))
    # Bug fix: ckpts[0] on an empty list raised IndexError before the
    # "unable to load" branch could ever run. Bail out early instead.
    if not ckpts:
        print("unable to load model from {}".format(ckpt_dir))
        return
    # NOTE(review): glob order is filesystem-dependent, so ckpts[0] is an
    # arbitrary match, not necessarily the newest checkpoint. If checkpoint
    # state files are present, tf.train.latest_checkpoint(ckpt_dir) is the
    # conventional way to pick the most recent one.
    ckpt = ckpts[0]
    if variables_to_restore is None:
        saver = tf.train.Saver()
    else:
        saver = tf.train.Saver(variables_to_restore)
    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        print("model loaded from {}".format(ckpt))