def get_var_list_to_restore():
    """Select the model variables that should be restored from a checkpoint.

    Starts from ``tf.model_variables()`` and applies two optional
    name-prefix filters taken from command-line flags, in this order:

    1. ``FLAGS.checkpoint_exclude_scopes`` -- comma-separated prefixes; a
       variable whose ``name`` starts with any of them is dropped.
    2. ``FLAGS.checkpoint_include_scopes`` -- comma-separated prefixes; when
       given, only variables whose ``name`` starts with one of them are kept.

    A flag that is ``None`` disables its filter. Blank entries (an empty or
    whitespace-only flag value, a trailing comma, ``"a,,b"``) are ignored:
    every string starts with ``""``, so keeping a blank prefix would exclude
    every variable (exclude flag) or match every variable (include flag).

    Returns:
        The variables selected for restoring. When no filter applies this is
        the unfiltered result of ``tf.model_variables()``.
    """

    def parse_scopes(flag_value):
        # Return the non-blank, stripped scopes as a tuple, because
        # str.startswith accepts a tuple of prefixes; () when the flag is unset.
        if flag_value is None:
            return ()
        return tuple(
            scope.strip() for scope in flag_value.split(',') if scope.strip()
        )

    variables_to_restore = tf.model_variables()

    # Exclusion runs first so that an include scope can never bring back a
    # variable that was explicitly excluded.
    exclusions = parse_scopes(FLAGS.checkpoint_exclude_scopes)
    if exclusions:
        variables_to_restore = [
            var for var in variables_to_restore
            if not var.name.startswith(exclusions)
        ]

    includes = parse_scopes(FLAGS.checkpoint_include_scopes)
    if includes:
        variables_to_restore = [
            var for var in variables_to_restore
            if var.name.startswith(includes)
        ]

    return variables_to_restore
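# [web-page navigation residue, not code] Comment list
# [web-page navigation residue, not code] Table of contents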