gradient_moment.py source code

python

Project: probabilistic_line_search Author: ProbabilisticNumerics
import tensorflow as tf  # needed for tf.name_scope below


def _GradMom(op, v, out_grad, batch_size, mom=2):
  """Wrapper function for the operation type-specific GradMom functions below.

  Inputs:
      :op: A tensorflow operation of type in VALID_TYPES.
      :v: The read-tensor of the trainable variable consumed by this operation.
      :out_grad: The tensor containing the gradient w.r.t. the output of
          the op (as computed by ``tf.gradients``).
      :batch_size: Batch size ``m`` (constant integer or scalar int tf.Tensor)
      :mom: Integer moment desired (defaults to 2)."""

  with tf.name_scope(op.name+"_grad_mom"):
    if op.type == "MatMul":
      return _MatMulGradMom(op, v, out_grad, batch_size, mom)
    elif op.type == "Conv2D":
      return _Conv2DGradMom(op, v, out_grad, batch_size, mom)
    elif op.type == "Add":
      return _AddGradMom(op, v, out_grad, batch_size, mom)
    else:
      raise ValueError("Don't know how to compute gradient moment for "
          "variable {}, consumed by operation of type {}".format(v.name,
          op.type))
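
The wrapper above only dispatches on ``op.type``; the type-specific functions it calls are not part of this excerpt. The following is a minimal sketch of the same dispatch pattern written with a lookup table instead of an if/elif chain. It runs without TensorFlow: the ``*_stub`` handlers, ``grad_mom_dispatch``, and the ``SimpleNamespace`` objects standing in for an op and a variable are all hypothetical and exist only for illustration.

```python
from types import SimpleNamespace


# Hypothetical stand-ins for the type-specific functions. The real ones
# build TensorFlow ops; these only report which branch was taken.
def _matmul_stub(op, v, out_grad, batch_size, mom):
    return ("MatMul", mom)


def _conv2d_stub(op, v, out_grad, batch_size, mom):
    return ("Conv2D", mom)


def _add_stub(op, v, out_grad, batch_size, mom):
    return ("Add", mom)


# Dispatch table keyed by op.type (an alternative to the if/elif chain).
_HANDLERS = {
    "MatMul": _matmul_stub,
    "Conv2D": _conv2d_stub,
    "Add": _add_stub,
}


def grad_mom_dispatch(op, v, out_grad, batch_size, mom=2):
    """Look up the handler for op.type and call it, or raise ValueError."""
    try:
        handler = _HANDLERS[op.type]
    except KeyError:
        raise ValueError(
            "Don't know how to compute gradient moment for variable {}, "
            "consumed by operation of type {}".format(v.name, op.type)
        ) from None
    return handler(op, v, out_grad, batch_size, mom)


# Fake op and variable objects carrying only the attributes the wrapper reads.
op = SimpleNamespace(name="dense/MatMul", type="MatMul")
v = SimpleNamespace(name="dense/kernel/read:0")
result = grad_mom_dispatch(op, v, out_grad=None, batch_size=32)
```

With a table, supporting a further op type means adding one dictionary entry, and the error path for unsupported types stays identical to the original ``ValueError``.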