stochastic_gradient_estimators.py source code

python

Project: lsdc  Author: febert
# TensorFlow-internal imports required by this snippet (TF 1.x, contrib era).
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import training


def get_mean_baseline(ema_decay=0.99, name=None):
  """ExponentialMovingAverage baseline.

  The EMA initializes to 0, which biases it toward 0 early in training. This
  baseline implements the bias correction term from Adam (section 3 of
  https://arxiv.org/pdf/1412.6980v8.pdf): for a stationary loss, the EMA after
  `t` updates has expectation `(1 - ema_decay^t) * E[loss]`, so dividing by
  `1 - ema_decay^t` removes the bias.

  Args:
    ema_decay: decay rate for the ExponentialMovingAverage.
    name: name for variable scope of the ExponentialMovingAverage.

  Returns:
    Callable baseline function that takes the `DistributionTensor` (unused) and
    the downstream `loss`, and returns a bias-corrected EMA of the loss.
  """

  def mean_baseline(_, loss):
    with vs.variable_scope(name, default_name="MeanBaseline"):
      reduced_loss = math_ops.reduce_mean(loss)

      ema = training.ExponentialMovingAverage(decay=ema_decay)
      update_op = ema.apply([reduced_loss])

      # The bias correction term requires keeping track of how many times the
      # EMA has been updated. Creating a variable here to do so. The global step
      # is not used because it may or may not track exactly the number of times
      # the EMA is updated.
      ema_var = ema.average(reduced_loss)
      assert ema_var is not None
      with ops.colocate_with(ema_var):
        num_updates = vs.get_variable(
            "local_ema_step", initializer=0, trainable=False)
      num_updates = num_updates.assign_add(1)
      bias_correction = 1. - math_ops.pow(ema_decay, math_ops.cast(
          num_updates, reduced_loss.dtype))

      with ops.control_dependencies([update_op]):
        baseline = ema.average(reduced_loss) / bias_correction

      return baseline

  return mean_baseline
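
Below is a minimal usage sketch, not part of the original file, assuming TF 1.x
graph mode: a constant dummy loss is fed through the baseline callable, with
`None` standing in for the unused `DistributionTensor` argument. The tensor
values and step count are illustrative only.

import tensorflow as tf

# Dummy per-example losses; reduce_mean(loss) is 3.0.
loss = tf.constant([2.0, 4.0])

baseline_fn = get_mean_baseline(ema_decay=0.99)
baseline = baseline_fn(None, loss)  # first argument is unused

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for step in range(3):
    # Each run applies one EMA update, bumps the local step counter, and
    # returns the bias-corrected average.
    print(step, sess.run(baseline))

Since `tf.train.ExponentialMovingAverage` does not zero-debias by default, the
raw EMA of a constant loss `L` after `t` updates is `L * (1 - 0.99^t)`; the
division by `1 - 0.99^t` cancels that factor exactly, so this sketch should
print approximately 3.0 from the very first step.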