def restore_checkpoint(self, sess, var_list=None, load_ema=True):
"""
Restores either the best weights or the most recent checkpoint, assuming the correct
variables have already been added to the tf default graph e.g., .get_prediction()
has been called the model stored in `self`.
Automatically detects if EMA weights exists, and if they do loads them instead
"""
    checkpoint = self.get_best_weights()
    if checkpoint is None:
        print("Loading most recent checkpoint")
        checkpoint = self.get_latest_checkpoint()
    else:
        print("Loading best weights")
    if load_ema:
        if var_list is None:
            # Same default used by `Saver`
            var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + \
                       tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)
        # Automatically check if there are EMA variables; if so, use those
        reader = tf.train.NewCheckpointReader(checkpoint)
        ema = tf.train.ExponentialMovingAverage(0)  # decay is irrelevant, only names are used
        ema_names = {ema.average_name(x): x for x in var_list
                     if reader.has_tensor(ema.average_name(x))}
        if len(ema_names) > 0:
            print("Found EMA weights, loading them")
            ema_vars = set(ema_names.values())
            # Map non-EMA variables to their own names and EMA-backed variables
            # to their shadow names, so `Saver` restores the averaged values
            var_list = {v.op.name: v for v in var_list if v not in ema_vars}
            var_list.update(ema_names)
    saver = tf.train.Saver(var_list)
    saver.restore(sess, checkpoint)
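
# --------------------------------------------------------------------------
# Illustration (not from the source): a minimal self-contained sketch of the
# remapping trick restore_checkpoint relies on. `tf.train.Saver` accepts a
# {checkpoint_name: variable} dict, so pointing each EMA shadow name
# (var.op.name + "/ExponentialMovingAverage") at the live variable makes
# `restore` load the averaged value into the live weights. The variable name
# and checkpoint path below are made up for the demo.
import tensorflow as tf

v = tf.get_variable("w", initializer=1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
maintain_op = ema.apply([v])  # creates the shadow variable "w/ExponentialMovingAverage"

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # shadow starts at v's initial value, 1.0
    sess.run(tf.assign(v, 5.0))
    sess.run(maintain_op)  # shadow = 0.9 * 1.0 + 0.1 * 5.0 = 1.4
    tf.train.Saver().save(sess, "/tmp/ema_demo")  # hypothetical path

    # Restore the shadow (averaged) value into the live variable
    shadow_name = ema.average_name(v)
    reader = tf.train.NewCheckpointReader("/tmp/ema_demo")
    assert reader.has_tensor(shadow_name)
    tf.train.Saver({shadow_name: v}).restore(sess, "/tmp/ema_demo")
    print(sess.run(v))  # prints 1.4: the EMA value, not the raw 5.0
# Preferring EMA weights when they exist is a common choice because averaged
# parameters typically generalize a bit better at evaluation time.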