stochastic_gradient_estimators.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
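# TensorFlow 1.x internal modules used below.
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import training


def get_mean_baseline(ema_decay=0.99, name=None):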
  """ExponentialMovingAverage baseline.

  Args:
    ema_decay: decay rate for the ExponentialMovingAverage.
    name: name for variable scope of the ExponentialMovingAverage.

  Returns:
    Callable baseline function that takes the `StochasticTensor` (unused) and
    the downstream `loss`, and returns a zero-debiased EMA of the mean of `loss`.
  """

  def mean_baseline(_, loss):
    with vs.variable_scope(name, default_name="MeanBaseline"):
      reduced_loss = math_ops.reduce_mean(loss)

      ema = training.ExponentialMovingAverage(decay=ema_decay, zero_debias=True)
      update_op = ema.apply([reduced_loss])

      with ops.control_dependencies([update_op]):
        # Using `identity` causes an op to be added in this context, which
        # triggers the update. Removing the `identity` means nothing is updated.
        baseline = array_ops.identity(ema.average(reduced_loss))

      return baseline

  return mean_baseline
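To make the arithmetic behind `mean_baseline` concrete, here is a minimal pure-Python sketch of a zero-debiased exponential moving average applied to per-batch mean losses. It assumes that `ExponentialMovingAverage(decay=ema_decay, zero_debias=True)` with no `num_updates` behaves like the familiar bias-corrected EMA: the shadow value starts at zero, is updated as `decay * shadow + (1 - decay) * value`, and is then divided by `1 - decay**t` after `t` updates. The names `ZeroDebiasedEMA` and `mean_baseline_trace` are hypothetical and are not part of TensorFlow.

```python
from typing import Iterable, List


class ZeroDebiasedEMA:
    """Hypothetical pure-Python illustration of a zero-debiased EMA.

    Assumption: mirrors the bias-corrected update used for
    ExponentialMovingAverage(zero_debias=True); this is not TensorFlow code.
    """

    def __init__(self, decay: float = 0.99) -> None:
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        self._biased = 0.0  # shadow value starts at zero, which biases early averages
        self._step = 0      # number of updates applied so far

    def update(self, value: float) -> float:
        # Ordinary EMA step on the zero-initialised shadow value.
        self._biased = self.decay * self._biased + (1.0 - self.decay) * value
        self._step += 1
        # Dividing by (1 - decay**t) removes the pull toward the zero init.
        return self._biased / (1.0 - self.decay ** self._step)


def mean_baseline_trace(batch_losses: Iterable[List[float]],
                        ema_decay: float = 0.99) -> List[float]:
    """Hypothetical helper: baseline value after each batch of losses."""
    ema = ZeroDebiasedEMA(ema_decay)
    trace = []
    for losses in batch_losses:
        reduced_loss = sum(losses) / len(losses)  # analogue of reduce_mean
        trace.append(ema.update(reduced_loss))
    return trace


if __name__ == "__main__":
    batches = [[1.0, 3.0], [4.0, 4.0], [0.0, 2.0]]
    print(mean_baseline_trace(batches, ema_decay=0.5))
```

With `ema_decay=0.5` the batch means are 2.0, 4.0 and 1.0, and the debiased baseline is 2.0, 10/3 and 2.0. Without the division by `1 - decay**t`, the first value would be only 1.0, an underestimate caused purely by the zero initialisation. In a score-function estimator, a baseline like this is typically subtracted from the loss before it multiplies the log-probability gradient, which can reduce variance.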