def filter_var_list(self, var_list):
    """Filter checkpoint variables down to those that should be restored.

    The selection is driven by ``self.to_restore``:

    * falsy (``None``, an empty list, ...): no filtering is done and
      ``var_list`` itself is returned unchanged (the same object).
    * a compiled regular expression: keep entries whose name matches via
      ``pattern.match`` (i.e. the match is anchored at the start of the name).
    * a list, tuple, set or frozenset of names: keep entries whose name is
      contained in that collection.

    Args:
        var_list (dict): Mapping of variable name to variable, i.e. the
            variables that can be restored from the checkpoint.

    Returns:
        dict: The variables to restore, keyed by name. A new dict is built
        whenever a filter is applied.

    Raises:
        TypeError: If ``self.to_restore`` is of an unsupported type.
    """
    to_restore = self.to_restore
    if not to_restore:
        return var_list

    # ``re._pattern_type`` was removed in Python 3.7 (renamed ``re.Pattern``);
    # the type of a compiled pattern is the portable way to spell it.
    if isinstance(to_restore, type(re.compile(''))):
        return {name: var for name, var in var_list.items()
                if to_restore.match(name)}

    if isinstance(to_restore, (list, tuple, set, frozenset)):
        # Build the lookup set once so each membership test is O(1) instead
        # of a linear scan of the list for every variable.
        wanted = set(to_restore)
        return {name: var for name, var in var_list.items()
                if name in wanted}

    raise TypeError('to_restore ({}) unsupported.'.format(type(to_restore)))
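# NOTE: the two lines below were web-page residue ("评论列表" = "comment list",
# "文章目录" = "table of contents") left over from copying this code off a blog;
# they are not code and are kept only as a comment.
# 评论列表
# 文章目录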