def get_mean_baseline(ema_decay=0.99, name=None):
    """Build a baseline function that tracks a moving average of the loss.

    Each call of the returned function builds, inside a variable scope, a
    mean-reduced copy of the loss and an `ExponentialMovingAverage` of it. It
    then returns the averaged value, gated on the EMA update having run.

    Args:
      ema_decay: decay rate for the ExponentialMovingAverage.
      name: name for variable scope of the ExponentialMovingAverage. When
        `None`, the scope name defaults to "MeanBaseline".

    Returns:
      Callable baseline function that takes the `StochasticTensor` (unused) and
      the downstream `loss`, and returns an EMA of the loss.

    NOTE(review): the moving average is applied to a plain tensor (the reduced
    loss), not a Variable. If TF zero-initialises the shadow variable in that
    case, early baselines are biased toward 0 until the average warms up.
    Confirm against the TF version in use.
    """

    def mean_baseline(_, loss):
        # The first argument (the `StochasticTensor`) is intentionally ignored;
        # this baseline depends only on the loss.
        with vs.variable_scope(name, default_name="MeanBaseline"):
            mean_loss = math_ops.reduce_mean(loss)
            averager = training.ExponentialMovingAverage(decay=ema_decay)
            ema_update = averager.apply([mean_loss])
            with ops.control_dependencies([ema_update]):
                # A control dependency only attaches to ops created inside this
                # block. `identity` is that op: it forces `ema_update` to run
                # whenever the baseline is evaluated. Returning the average
                # directly would leave the EMA never updated.
                return array_ops.identity(averager.average(mean_loss))

    return mean_baseline
评论列表
文章目录