def get_restore_vars(self, save_file, all_vars=None):
"""Create the `var_list` init argument to tf.Saver from save_file.
Extracts the subset of variables from tf.global_variables that match the
name and shape of variables saved in the checkpoint file, and returns these
as a list of variables to restore.
To support multi-model training, a model prefix is prepended to all
tf global_variable names, although this prefix is stripped from
all variables before they are saved to a checkpoint. Thus,
Args:
save_file: path of tf.train.Saver checkpoint.
Returns:
dict: checkpoint variables.
"""
    reader = tf.train.NewCheckpointReader(save_file)
    var_shapes = reader.get_variable_to_shape_map()
    log.info('Saved vars:\n' + str(list(var_shapes.keys())))
    var_shapes = {  # Strip the prefix off saved var names.
        strip_prefix_from_name(self.params['model_params']['prefix'], name): shape
        for name, shape in var_shapes.items()}
    # Map old vars from checkpoint to new vars via load_param_dict.
    mapped_var_shapes = self.remap_var_list(var_shapes)
    log.info('Saved shapes:\n' + str(mapped_var_shapes))
    if all_vars is None:
        all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
    all_vars = strip_prefix(self.params['model_params']['prefix'], all_vars)
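    # (strip_prefix returns a dict mapping the stripped names to the
    # variables themselves, which the .items() lookups below rely on.)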
    # Specify which vars are to be restored vs. reinitialized.
    if self.load_param_dict is None:
        restore_vars = {name: var for name, var in all_vars.items()
                        if name in mapped_var_shapes}
    else:
        # Associate checkpoint names with the corresponding current
        # variables via a direct dict lookup (all_vars is keyed by name).
        load_var_dict = {}
        for ckpt_var_name, curr_var_name in self.load_param_dict.items():
            if curr_var_name in all_vars:
                load_var_dict[ckpt_var_name] = all_vars[curr_var_name]
        restore_vars = load_var_dict
    restore_vars = self.filter_var_list(restore_vars)
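    # (filter_var_list is assumed to prune the restore set further, e.g.
    # according to user-specified load parameters.)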
    # Ensure the vars to be restored have the correct shape.
    var_list = {}
    for name, var in restore_vars.items():
        var_shape = var.get_shape().as_list()
        # Use .get so names absent from the checkpoint's shape map are
        # skipped (and thus reinitialized) rather than raising a KeyError.
        if var_shape == mapped_var_shapes.get(name):
            var_list[name] = var
    return var_list
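
# Example usage (a minimal sketch): the returned dict is meant to be fed
# straight to tf.train.Saver as its `var_list` argument. Here `dbinterface`,
# `sess`, and `save_path` are hypothetical stand-ins for the object that owns
# this method, an active tf.Session, and the checkpoint path:
#
#     restore_vars = dbinterface.get_restore_vars(save_path)
#     saver = tf.train.Saver(var_list=restore_vars)
#     saver.restore(sess, save_path)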