def _get_vars_to_collections(variables):
    """Returns a dict mapping variables to the collections they appear in.

    Args:
      variables: Either a dict (its variables are extracted with
        `variable_map_items`) or an iterable of `tf.Variable` objects. Any
        iterable is accepted, including one-shot iterators/generators.

    Returns:
      A `collections.defaultdict(list)` mapping each variable from
      `variables` that appears in at least one collection of its graph to the
      list of names of those collections. The legacy
      `tf.GraphKeys.GLOBAL_VARIABLES` key ("variables") is reported as
      "global_variables". Variables in no collection are absent from the
      mapping (lookups of them yield an empty list via the defaultdict).
    """
    var_to_collections = collections.defaultdict(list)
    if isinstance(variables, dict):
        variables = [v for _, v in variable_map_items(variables)]
    else:
        # Materialize once: the input is traversed both to discover graphs and
        # to match collection entries, so a one-shot iterator would otherwise
        # be exhausted after the first pass.
        variables = list(variables)
    # Built once so each collection needs only a single set intersection
    # instead of re-hashing every variable for every collection.
    variable_set = set(variables)
    for graph in {v.graph for v in variables}:
        for collection_name in list(graph.collections):
            entries = {entry for entry in graph.get_collection(collection_name)
                       if isinstance(entry, tf.Variable)}
            # For legacy reasons, tf.GraphKeys.GLOBAL_VARIABLES == "variables".
            # Correcting for this here, to avoid confusion.
            if collection_name == tf.GraphKeys.GLOBAL_VARIABLES:
                collection_name = "global_variables"
            for var in entries & variable_set:
                var_to_collections[var].append(collection_name)
    return var_to_collections
评论列表
文章目录