base_model.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:deepmodels 作者: learningsociety 项目源码 文件源码
def _checkpoint_prefix(path):
  """Map a checkpoint companion file to the prefix Saver.restore expects.

  Strips a trailing ".index", ".meta" or ".data-XXXXX-of-YYYYY"; any other
  path is returned unchanged.
  """
  for suffix in (".index", ".meta"):
    if path.endswith(suffix):
      return path[:-len(suffix)]
  head, sep, tail = path.rpartition(".data-")
  if sep and "-of-" in tail:
    return head
  return path


def load_model(ckpt_dir, variables_to_restore=None, sess=None):
  """Load model weights.

  Assuming the model graph has been built.

  The checkpoint is chosen as follows: if `tf.train.latest_checkpoint` finds
  one recorded for `ckpt_dir`, that one is used; otherwise the most recently
  modified file matching "*.ckpt*" is taken and reduced to its checkpoint
  prefix (so a ".index"/".meta"/".data-*" file is never handed to the saver).

  Args:
    ckpt_dir: checkpoint directory.
    variables_to_restore: which variables to load from checkpoint. When None,
      a default `tf.train.Saver()` is used.
    sess: optional session to restore into. When None, a temporary session is
      opened and closed inside this function, as before; in that case the
      restored values do not outlive the call, so pass a session if the
      weights are needed afterwards.

  Returns:
    None. Prints a message and returns early when `ckpt_dir` does not exist
    or contains no "*.ckpt*" files.
  """
  if not os.path.exists(ckpt_dir):
    print("checkpoint dir {} not exist.".format(ckpt_dir))
    return

  ckpts = glob.glob(os.path.join(ckpt_dir, "*.ckpt*"))
  if not ckpts:
    # Previously this case crashed with IndexError on ckpts[0].
    print("unable to load model from {}".format(ckpt_dir))
    return

  # Prefer the checkpoint recorded by TensorFlow itself; fall back to the
  # newest matching file for directories without a checkpoint state file.
  ckpt = tf.train.latest_checkpoint(ckpt_dir)
  if not ckpt:
    newest = max(ckpts, key=os.path.getmtime)
    ckpt = _checkpoint_prefix(newest)

  if variables_to_restore is None:
    saver = tf.train.Saver()
  else:
    saver = tf.train.Saver(variables_to_restore)

  if sess is not None:
    saver.restore(sess, ckpt)
  else:
    with tf.Session() as tmp_sess:
      saver.restore(tmp_sess, ckpt)
  print("model loaded from {}".format(ckpt))
  # NOTE: slim.assign_from_checkpoint_fn(ckpt, variables_to_restore,
  # ignore_missing_vars=True) is an alternative way to load the weights.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号