def get_variables_in_scope(scope, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
    """Returns a tuple of `tf.Variable`s in a scope for a given collection.

    Args:
        scope: `tf.VariableScope` or string to retrieve variables from. If the
            resolved scope name is empty (e.g. the root variable scope), every
            variable in `collection` is returned.
        collection: Collection to restrict query to. By default this is
            `tf.GraphKeys.TRAINABLE_VARIABLES`, which doesn't include
            non-trainable variables such as moving averages.

    Returns:
        A tuple of `tf.Variable` objects.
    """
    scope_name = get_variable_scope_name(scope)

    if scope_name:
        # `tf.get_collection` filters by matching the scope as a regular
        # expression against variable names, so escape any special characters
        # (e.g. "."). Append a closing slash so that scopes which merely share
        # this name as a prefix (e.g. "foo_bar" for scope "foo") are excluded.
        scope_name = re.escape(scope_name) + "/"
    # An empty scope name is passed through unchanged: the empty pattern
    # matches every variable, whereas "/" would match none of them.

    return tuple(tf.get_collection(collection, scope_name))
评论列表
文章目录