def load(self, model_name='main'):
checkpoint_dir = os.path.join(self.checkpoint_dir, self.model_dir)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
return True, counter
else:
print("Failed to find a checkpoint")
return False, 0
评论列表
文章目录