def weighted_resample(inputs, weights, overall_rate, scope=None,
                      mean_decay=0.999, warmup=10, seed=None):
  """Performs an approximate weighted resampling of `inputs`.

  This method chooses elements from `inputs` where each item's rate of
  selection is proportional to its value in `weights`, and the average
  rate of selection across all inputs (and many invocations!) is
  `overall_rate`.

  Args:
    inputs: A list (or tuple) of tensors whose first dimension is
      `batch_size`.
    weights: A `[batch_size]`-shaped tensor with each batch member's weight.
    overall_rate: Desired overall rate of resampling.
    scope: Scope to use for the op.
    mean_decay: How quickly to decay the running estimate of the mean weight.
      The decay actually applied is never larger than this value.
    warmup: Until the resulting tensor has been evaluated `warmup`
      times, the resampling method uses the true mean over all calls
      as its weight estimate (as long as that does not require a decay
      above `mean_decay`), rather than a decayed mean. Values below 1
      behave like 1, i.e. no warmup period.
    seed: Random seed.

  Returns:
    A tuple `(outputs, rates)`:
      outputs: A list of tensors exactly like `inputs`, but with an unknown
        (and possibly zero) first dimension.
      rates: A tensor containing the effective resampling rate used for
        each output.
  """
  # Algorithm: Just compute rates as weights/mean_weight *
  # overall_rate. This way the average weight corresponds to the
  # overall rate, and a weight twice the average has twice the rate,
  # etc.

  # Accept tuples as well as lists; `[rates] + inputs` below needs a list.
  inputs = list(inputs)
  with ops.name_scope(scope, 'weighted_resample',
                      inputs + [weights]) as opscope:
    # First: Maintain a running estimated mean weight, with decay
    # adjusted (by also maintaining an invocation count) during the
    # warmup period so that at the beginning, there aren't too many
    # zeros mixed in, throwing the average off.
    with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
      count_so_far = variable_scope.get_local_variable(
          'resample_count', initializer=0)
      estimated_mean = variable_scope.get_local_variable(
          'estimated_mean', initializer=0.0)
      count = count_so_far.assign_add(1)
      # While count <= warmup the decay is (count - 1) / count, which makes
      # the moving average the plain cumulative mean of the batch means.
      # Once count > warmup, (count - 1) / warmup >= 1, so the outer
      # `minimum` settles on `mean_decay`. The divisor is clamped to >= 1
      # so that `warmup <= 0` cannot produce 0 / 0 = NaN on the first call.
      real_decay = math_ops.minimum(
          math_ops.truediv(
              count - 1,
              math_ops.maximum(math_ops.minimum(count, warmup), 1)),
          mean_decay)
      batch_mean = math_ops.reduce_mean(weights)
      # zero_debias=False: the warmup decay schedule above already keeps
      # the zero initializer from biasing the estimate.
      mean = moving_averages.assign_moving_average(
          estimated_mean, batch_mean, real_decay, zero_debias=False)

    # Then, normalize the weights into rates using the mean weight and
    # overall target rate:
    # NOTE(review): if every weight seen so far is zero, `mean` is zero and
    # these rates become NaN -- confirm callers never feed all-zero weights.
    rates = weights * overall_rate / mean

    # The rates ride along as element 0 so they are resampled together with
    # the inputs; presumably `resample_at_rate` preserves list order.
    results = resample_at_rate([rates] + inputs, rates,
                               scope=opscope, seed=seed, back_prop=False)
    return (results[1:], results[0])
评论列表
文章目录