def _GradMom(op, v, out_grad, batch_size, mom=2):
    """Dispatch to the op-type-specific gradient-moment function.

    Picks ``_MatMulGradMom``, ``_Conv2DGradMom`` or ``_AddGradMom`` based on
    ``op.type`` and calls it inside a name scope named
    ``op.name + "_grad_mom"``.

    Args:
      op: A tensorflow operation of type in VALID_TYPES; this dispatcher
        handles the types "MatMul", "Conv2D" and "Add".
      v: The read-tensor of the trainable variable consumed by ``op``.
      out_grad: The tensor containing the gradient with respect to the
        output of ``op`` (as computed by ``tf.gradients``).
      batch_size: Batch size ``m`` (constant integer or scalar int
        tf.Tensor).
      mom: Integer moment desired (defaults to 2).

    Returns:
      Whatever the selected type-specific function returns.

    Raises:
      ValueError: If ``op.type`` is not one of the handled types.
    """
    with tf.name_scope(op.name + "_grad_mom"):
        op_type = op.type
        # Choose the handler first and make a single shared call below; each
        # helper name is only looked up on the branch that is actually taken.
        if op_type == "MatMul":
            handler = _MatMulGradMom
        elif op_type == "Conv2D":
            handler = _Conv2DGradMom
        elif op_type == "Add":
            handler = _AddGradMom
        else:
            raise ValueError(
                "Don't know how to compute gradient moment for "
                "variable {}, consumed by operation of type {}".format(
                    v.name, op_type))
        return handler(op, v, out_grad, batch_size, mom)
gradient_moment.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录