def get_saver(scope, collections=(tf.GraphKeys.GLOBAL_VARIABLES,),  # pylint: disable=redefined-outer-name
              context=None, **kwargs):
    """Creates a `tf.train.Saver` over a scope's variables, keyed by normalized name.

    Variable names are normalized by stripping the scope prefix, so a
    checkpoint written through this saver can be restored into a different but
    structurally similar scope or module by building a matching
    `tf.train.Saver` there.

    Args:
      scope: Scope or module whose variables will be saved or restored.
      collections: Sequence of variable collection keys that the saver is
        restricted to. Defaults to `tf.GraphKeys.GLOBAL_VARIABLES`, which
        covers moving-average variables in addition to trainable variables.
      context: Optional scope or module that is either `scope` itself or one
        of its parents. When provided, it is used as the prefix to strip
        instead.
      **kwargs: Additional keyword arguments forwarded to `tf.train.Saver`.

    Returns:
      A `tf.train.Saver` covering the variables found in the scope or module.
    """
    names_to_variables = {}
    for collection_key in collections:
        # Collections are merged in the order given; on a name clash the entry
        # from the later collection replaces the earlier one.
        normalized = get_normalized_variable_map(scope, collection_key, context)
        names_to_variables.update(normalized)

    return tf.train.Saver(var_list=names_to_variables, **kwargs)
评论列表
文章目录