# Internal TensorFlow 1.x modules this snippet relies on (import paths as used
# by tf.contrib.bayesflow; adjust if your version differs).
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import training


def get_mean_baseline(ema_decay=0.99, name=None):
  """ExponentialMovingAverage baseline.

  The EMA initializes to 0, which introduces a bias. This baseline implements
  the bias correction term from Adam (section 3 of
  https://arxiv.org/pdf/1412.6980v8.pdf), dividing by `1 - ema_decay^t`, where
  `t` is the step count.

  Args:
    ema_decay: decay rate for the ExponentialMovingAverage.
    name: name for the variable scope of the ExponentialMovingAverage.

  Returns:
    Callable baseline function that takes the `DistributionTensor` (unused) and
    the downstream `loss`, and returns an EMA of the loss.
  """

  def mean_baseline(_, loss):
    with vs.variable_scope(name, default_name="MeanBaseline"):
      reduced_loss = math_ops.reduce_mean(loss)

      ema = training.ExponentialMovingAverage(decay=ema_decay)
      update_op = ema.apply([reduced_loss])

      # The bias correction term requires keeping track of how many times the
      # EMA has been updated. Creating a variable here to do so. The global
      # step is not used because it may or may not track exactly the number of
      # times the EMA is updated.
      ema_var = ema.average(reduced_loss)
      assert ema_var is not None
      with ops.colocate_with(ema_var):
        num_updates = vs.get_variable(
            "local_ema_step", initializer=0, trainable=False)
      num_updates = num_updates.assign_add(1)
      bias_correction = 1. - math_ops.pow(ema_decay, math_ops.cast(
          num_updates, reduced_loss.dtype))

      with ops.control_dependencies([update_op]):
        baseline = ema.average(reduced_loss) / bias_correction

      return baseline

  return mean_baseline
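
To see what the Adam-style correction buys, here is a small, self-contained sketch in plain Python (illustrative only, not part of the original code) that tracks a constant loss of 1.0: the raw zero-initialized EMA starts far below the true mean, while dividing by `1 - ema_decay^t` recovers it immediately.

# Minimal numeric sketch of the bias correction; plain Python, illustrative only.
decay = 0.9
ema = 0.0  # the moving average starts at 0, just like the TF shadow variable
for t in range(1, 6):
  loss_t = 1.0  # constant loss for illustration
  ema = decay * ema + (1 - decay) * loss_t   # standard EMA update
  corrected = ema / (1 - decay ** t)         # Adam-style bias correction
  print(t, round(ema, 4), round(corrected, 4))
# Raw EMA: 0.1, 0.19, 0.271, ... (slowly approaching 1.0);
# corrected value: 1.0 at every step, so the startup bias is gone.

And a hedged usage sketch, assuming TensorFlow 1.x graph mode (the placeholder, feed values, and session boilerplate below are illustrative and not from the original source): the returned callable ignores its first argument and yields the bias-corrected EMA of the mean loss, updating the EMA and the step counter each time the node is evaluated.

import tensorflow as tf

baseline_fn = get_mean_baseline(ema_decay=0.99)
per_example_loss = tf.placeholder(tf.float32, shape=[None])  # hypothetical input
baseline = baseline_fn(None, per_example_loss)  # first argument is unused

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(3):
    # Each fetch applies the EMA update and increments "local_ema_step".
    print(sess.run(baseline, feed_dict={per_example_loss: [0.8, 1.1, 1.0]}))

A baseline like this is typically subtracted from the loss in a score-function (REINFORCE) gradient estimator, which reduces the variance of the gradient estimate without biasing its expectation.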