resample.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def weighted_resample(inputs, weights, overall_rate, scope=None,
                      mean_decay=0.999, warmup=10, seed=None):
  """Performs an approximate weighted resampling of `inputs`.

  This method chooses elements from `inputs` where each item's rate of
  selection is proportional to its value in `weights`, and the average
  rate of selection across all inputs (and many invocations!) is
  `overall_rate`.

  Args:
    inputs: A sequence (list or tuple) of tensors whose first dimension is
      `batch_size`.
    weights: A `[batch_size]`-shaped tensor with each batch member's weight.
    overall_rate: Desired overall rate of resampling.
    scope: Scope to use for the op.
    mean_decay: How quickly to decay the running estimate of the mean weight.
    warmup: Until the resulting tensor has been evaluated `warmup`
      times, the resampling method uses the true mean over all calls
      as its weight estimate, rather than a decayed mean.
    seed: Random seed.

  Returns:
    A list of tensors exactly like `inputs`, but with an unknown (and
      possibly zero) first dimension.
    A tensor containing the effective resampling rate used for each output.

  """
  # Algorithm: Just compute rates as weights/mean_weight *
  # overall_rate. This way the average weight corresponds to the
  # overall rate, and a weight twice the average has twice the rate,
  # etc.

  # Accept any sequence of tensors (e.g. a tuple), not only a list: the
  # `[rates] + inputs` concatenation further down requires a list.
  inputs = list(inputs)

  with ops.name_scope(scope, 'weighted_resample', inputs) as opscope:
    # First: Maintain a running estimated mean weight, with decay
    # adjusted (by also maintaining an invocation count) during the
    # warmup period so that at the beginning, there aren't too many
    # zeros mixed in, throwing the average off.

    with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
      # Number of times the update below has been evaluated (starts at 0).
      count_so_far = variable_scope.get_local_variable(
          'resample_count', initializer=0)

      # Running estimate of the mean weight (starts at 0.0).
      estimated_mean = variable_scope.get_local_variable(
          'estimated_mean', initializer=0.0)

      count = count_so_far.assign_add(1)
      # Decay schedule:
      # - While count <= warmup, this is (count - 1) / count, which turns the
      #   moving average into a plain cumulative mean over all calls so far.
      #   On the first call it is 0, so the initial 0.0 estimate is discarded.
      # - Once count > warmup, (count - 1) / warmup is >= 1, so the outer
      #   minimum settles at `mean_decay`.
      # The outer minimum also caps the decay at `mean_decay` during warmup.
      real_decay = math_ops.minimum(
          math_ops.truediv((count - 1), math_ops.minimum(count, warmup)),
          mean_decay)

      batch_mean = math_ops.reduce_mean(weights)
      mean = moving_averages.assign_moving_average(
          estimated_mean, batch_mean, real_decay, zero_debias=False)

    # Then, normalize the weights into rates using the mean weight and
    # overall target rate.
    # NOTE(review): if every weight seen so far is zero, `mean` is zero and
    # this division yields inf/nan rates — confirm callers never pass
    # all-zero weights.
    rates = weights * overall_rate / mean

    # The rates are prepended to the inputs so that they are resampled
    # together with them. results[0] therefore holds the rate of each
    # surviving row, and results[1:] holds the resampled inputs.
    results = resample_at_rate([rates] + inputs, rates,
                               scope=opscope, seed=seed, back_prop=False)

    return (results[1:], results[0])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号