gradient_moment.py source code

Project: probabilistic_line_search
Author: ProbabilisticNumerics
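```python
import tensorflow as tf


def _MatMulGradMom(op, W, out_grad, batch_size, mom=2):
  """Computes the gradient moment for a weight matrix through a MatMul operation.

  Assumes ``Z = tf.matmul(A, W)``, where ``W`` is a d1 x d2 weight matrix and
  ``A`` holds the n x d1 activations of the previous layer (n being the batch
  size). ``out_grad`` is the gradient w.r.t. ``Z``, as computed by
  ``tf.gradients()``. Transposed MatMul operations are not supported.

  Inputs:
      :op: The MatMul operation
      :W: The weight matrix (the tensor, not the variable)
      :out_grad: The gradient tensor w.r.t. the output of the op
      :batch_size: Batch size n (constant integer or scalar int tf.Tensor)
      :mom: Integer moment desired (defaults to 2)

  Returns:
      A d1 x d2 tensor, the same shape as ``W``."""

  assert op.type == "MatMul"
  t_a, t_b = op.get_attr("transpose_a"), op.get_attr("transpose_b")
  assert W is op.inputs[1] and not t_a and not t_b

  A = op.inputs[0]
  # Element-wise powers of the incoming gradient and of the activations.
  out_grad_pow = tf.pow(out_grad, mom)
  A_pow = tf.pow(A, mom)
  # A_pow^T @ out_grad_pow sums the outer products a_i^mom (x) g_i^mom over the batch.
  grad_mom = tf.matmul(A_pow, out_grad_pow, transpose_a=True)
  # Assuming a mean-reduced loss, example i's gradient w.r.t. W is
  # n * outer(a_i, out_grad_i), so for mom=2 the factor n turns the sum above
  # into the batch mean of the squared per-example gradients.
  # tf.mul was removed in TensorFlow 1.0 (tf.multiply replaces it); the cast
  # lets batch_size be either a Python int or an int tensor.
  return tf.multiply(tf.cast(batch_size, grad_mom.dtype), grad_mom)
```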
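The scaling by `batch_size` can be checked with a small NumPy sketch that compares the closed form against explicitly computed per-example gradients. It assumes, as the function appears to, that the loss is the batch mean of n per-example losses, so row i of `out_grad` equals 1/n times example i's own gradient w.r.t. `Z`. The helpers `matmul_grad_mom_np` and `per_example_moment` are hypothetical illustrations, not part of the project.

```python
import numpy as np


def matmul_grad_mom_np(A, out_grad, batch_size, mom=2):
  """Hypothetical NumPy mirror of _MatMulGradMom: n * (A**mom)^T (out_grad**mom)."""
  return batch_size * (A ** mom).T @ (out_grad ** mom)


def per_example_moment(A, out_grad, mom=2):
  """Hypothetical reference: batch mean of per-example gradients w.r.t. W, raised to mom.

  Assumes out_grad = dL/dZ for a mean-reduced loss L, so example i's own
  gradient dl_i/dZ_i is n * out_grad[i].
  """
  n = A.shape[0]
  per_example = [np.outer(A[i], n * out_grad[i]) for i in range(n)]
  return np.mean([g ** mom for g in per_example], axis=0)


rng = np.random.default_rng(0)
n, d1, d2 = 8, 5, 3
A = rng.standard_normal((n, d1))
out_grad = rng.standard_normal((n, d2)) / n  # gradient of a mean-reduced loss

# mom=2: the single factor n reproduces the mean of squared per-example gradients.
print(np.allclose(matmul_grad_mom_np(A, out_grad, n, mom=2),
                  per_example_moment(A, out_grad, mom=2)))  # True

# mom=3: under the same assumption the matching factor is n**(mom - 1).
print(np.allclose(n ** (3 - 1) * (A ** 3).T @ (out_grad ** 3),
                  per_example_moment(A, out_grad, mom=3)))  # True
```

Under this mean-loss assumption, the `batch_size` factor is exact for the default `mom=2`; for a general moment the same derivation gives a factor of n**(mom - 1), as the second check illustrates for mom=3.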