def initialize(self, no_scratch=False):
    """Fetch record then uses tf's saver.restore.

    Initialize the session's variables, restoring from a checkpoint if asked.

    When ``self.do_restore`` is true, the checkpoint path comes from:

    - ``self.from_ckpt``, if it is set; otherwise
    - ``self.load_data``, which is first populated by calling
      ``self.load_rec()`` if it is still None.

    The variables chosen by ``self.get_restore_vars`` are restored with
    ``tf.train.Saver.restore``. Every other global/local variable is then
    freshly initialized. If restoring is disabled, or no checkpoint path is
    available, all global and local variables are initialized from scratch.

    Args:
        no_scratch: Accepted for interface compatibility; not read here.

    Side effects:
        May call ``self.load_rec()``. Sets ``self.all_vars`` when a
        checkpoint is restored. Runs restore/initializer ops in ``self.sess``.

    Raises:
        AssertionError: if any variable is still uninitialized after a
            restore (the assert is skipped under ``python -O``).
    """
    # None means "nothing to restore"; the scratch-init fallback below keys on
    # this, so a loaded record without a checkpoint path still gets its
    # variables initialized instead of being left uninitialized.
    ckpt_filename = None

    if self.do_restore:
        # First, determine which checkpoint to use.
        if self.from_ckpt is not None:
            # Use a cached checkpoint file.
            ckpt_filename = self.from_ckpt
            log.info('Restoring variables from checkpoint %s ...' % ckpt_filename)
        else:
            # Otherwise, use a database checkpoint; fetch the record only if
            # it has not been loaded yet.
            if self.load_data is None:
                self.load_rec()
            if self.load_data is not None:
                rec, ckpt_filename = self.load_data
                log.info('Restoring variables from record %s (step %d)...' %
                         (str(rec['_id']), rec['step']))

    if ckpt_filename is not None:
        prefix = self.params['model_params']['prefix']
        all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
        self.all_vars = strip_prefix(prefix, all_vars)

        # Next, determine which vars should be restored from the specified checkpoint.
        restore_vars = self.get_restore_vars(ckpt_filename, self.all_vars)
        restore_stripped = strip_prefix(prefix, list(restore_vars.values()))
        restore_names = [name for name, _ in restore_stripped.items()]

        # Actually load the vars.
        log.info('Restored Vars:\n' + str(restore_names))
        tf_saver_restore = tf.train.Saver(restore_vars)
        tf_saver_restore.restore(self.sess, ckpt_filename)
        log.info('... done restoring.')

        # Reinitialize all other, unrestored vars. Use a single pass over
        # all_vars, and a set so each membership test is O(1) instead of a
        # scan over restore_names.
        restored = set(restore_names)
        unrestored = [(name, var) for name, var in self.all_vars.items()
                      if name not in restored]
        unrestored_var_names = [name for name, _ in unrestored]
        unrestored_vars = [var for _, var in unrestored]
        log.info('Unrestored Vars:\n' + str(unrestored_var_names))
        self.sess.run(tf.variables_initializer(unrestored_vars))  # initialize variables not restored

        # Build and run the check op once and reuse the result for the
        # message. The original built it twice, adding a second graph op.
        uninitialized = self.sess.run(tf.report_uninitialized_variables())
        assert len(uninitialized) == 0, uninitialized
    else:
        # Restoring disabled or no checkpoint found: initialize from scratch.
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
评论列表
文章目录