import os

import tensorflow as tf


def _load(self, sess):
    """Restore model variables (optionally their EMA shadows) from a checkpoint."""
    config = self.config
    # Map checkpoint keys (variable names without the ":0" output suffix) to variables.
    # tf.all_variables() is the deprecated TF 1.x alias of tf.global_variables().
    vars_ = {var.name.split(":")[0]: var for var in tf.all_variables()}
    if config.load_ema:
        # Restore trainable variables from their exponential-moving-average
        # shadow values instead of the raw training-time values.
        ema = self.model.var_ema
        for var in tf.trainable_variables():
            del vars_[var.name.split(":")[0]]
            vars_[ema.average_name(var)] = var
    saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep)

    # Resolve the checkpoint path: explicit path > explicit step > latest checkpoint.
    if config.load_path:
        save_path = config.load_path
    elif config.load_step > 0:
        save_path = os.path.join(
            config.save_dir, "{}-{}".format(config.model_name, config.load_step))
    else:
        save_dir = config.save_dir
        checkpoint = tf.train.get_checkpoint_state(save_dir)
        assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
        save_path = checkpoint.model_checkpoint_path
    print("Loading saved model from {}".format(save_path))
    saver.restore(sess, save_path)
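The `load_ema` branch is the subtle part: instead of restoring each trainable variable from its own checkpoint entry, it rebinds the variable to the checkpoint key of its exponential-moving-average shadow, so `Saver.restore` writes the averaged weights into the live variables. Below is a minimal TF 1.x sketch of the same mapping trick; the variable `w` and the decay value are hypothetical, not taken from the original code.

    import tensorflow as tf

    # Hypothetical one-variable graph tracked by an EMA (illustrative only).
    w = tf.Variable(1.0, name="w")
    ema = tf.train.ExponentialMovingAverage(decay=0.999)
    maintain_op = ema.apply([w])  # creates the shadow variable "w/ExponentialMovingAverage"

    # ema.average_name(w) is the shadow variable's checkpoint key. Mapping that
    # key to the live variable makes Saver.restore load the *averaged* value
    # into w, just as the load_ema branch above does for every trainable variable.
    saver = tf.train.Saver({ema.average_name(w): w})

TF 1.x also provides `ema.variables_to_restore()`, which builds this name-to-variable mapping automatically; the hand-rolled loop here achieves the same effect while keeping non-trainable variables under their ordinary names.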