util.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def get_normalized_variable_map(scope_or_module,
                                collection=tf.GraphKeys.GLOBAL_VARIABLES,
                                context=None,
                                group_sliced_variables=True):
  """Maps prefix-stripped variable names to the variables found in a scope.

  Variables are looked up under the scope named by `scope_or_module`, and each
  one is keyed by its name with the leading `context` scope name removed.

  Args:
    scope_or_module: Scope or module whose variables should be collected.
    collection: Graph collection to query. Defaults to
        `tf.GraphKeys.GLOBAL_VARIABLES`.
    context: Scope or module whose name is stripped from the front of every
        variable name. It must be a prefix of (or equal to) the scope of
        `scope_or_module`. When `None`, `scope_or_module` itself is used.
    group_sliced_variables: Boolean. If True, the variables are split via
        `_get_sliced_variables` and each sliced (partitioned) variable appears
        as a single entry holding its group; if False, every variable returned
        by the query gets its own (key, value) entry.

  Returns:
    Dictionary from normalized name to a `tf.Variable`, or to a group of
    variables when the entry is a sliced variable.

  Raises:
    ValueError: If the name of `context` is not a scope prefix of the name of
        `scope_or_module`, or if a sliced variable group normalizes to a name
        that is already present in the map.
  """
  scope_name = get_variable_scope_name(scope_or_module)
  prefix = get_variable_scope_name(
      scope_or_module if context is None else context)

  if not _is_scope_prefix(scope_name, prefix):
    raise ValueError("Scope '{}' is not prefixed by '{}'.".format(
        scope_name, prefix))

  # Number of leading characters to drop from each name. A non-empty prefix
  # drops one extra character (presumably the "/" scope separator).
  if prefix:
    strip_count = len(prefix) + 1
  else:
    strip_count = 0

  found_vars = get_variables_in_scope(scope_name, collection)

  if group_sliced_variables:
    unsliced, sliced_groups = _get_sliced_variables(found_vars)
  else:
    unsliced, sliced_groups = found_vars, {}

  normalized = {}
  for variable in unsliced:
    normalized[variable.op.name[strip_count:]] = variable

  # Add grouped entries last so a clash with an existing key is detected
  # instead of silently overwriting it.
  for group_name, group in sliced_groups.items():
    key = group_name[strip_count:]
    if key in normalized:
      raise ValueError("Mixing slices and non-slices with the same name: " +
                       str(key))
    normalized[key] = group

  return normalized
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号