import tensorflow as tf


def _MatMulGradMom(op, W, out_grad, batch_size, mom=2):
    """Computes gradient moment for a weight matrix through a MatMul operation.

    Assumes ``Z = tf.matmul(A, W)``, where ``W`` is a d1xd2 weight matrix and
    ``A`` is the nxd1 matrix of activations of the previous layer (n being the
    batch size). ``out_grad`` is the gradient w.r.t. ``Z``, as computed by
    ``tf.gradients()``. Transposes in the MatMul operation are not allowed.

    Inputs:
      :op: The MatMul operation
      :W: The weight matrix (the tensor, not the variable)
      :out_grad: The gradient tensor w.r.t. the output of the op
      :batch_size: Batch size n (constant integer or scalar int tf.Tensor)
      :mom: Integer moment desired (defaults to 2)"""

    # Only plain MatMul ops with W as the second, untransposed input are supported.
    assert op.type == "MatMul"
    t_a, t_b = op.get_attr("transpose_a"), op.get_attr("transpose_b")
    assert W is op.inputs[1] and not t_a and not t_b

    # Raise activations and output gradients to the mom-th power element-wise;
    # A_pow^T @ out_grad_pow then sums the per-example products over the batch.
    A = op.inputs[0]
    out_grad_pow = tf.pow(out_grad, mom)
    A_pow = tf.pow(A, mom)
    moment = tf.matmul(A_pow, out_grad_pow, transpose_a=True)

    # tf.mul was removed in TF >= 1.0; use tf.multiply and cast the batch size
    # so the scalar factor matches the dtype of the moment tensor.
    return tf.multiply(tf.cast(batch_size, moment.dtype), moment)
# Source: gradient_moment.py
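
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of gradient_moment.py). It assumes a
# TF 1.x-style graph and session; the toy model below, the names A/W/Z/loss,
# and the use of Z.op to recover the MatMul operation are assumptions made
# purely for demonstration.
import numpy as np

n, d1, d2 = 32, 10, 5                      # batch size and layer widths
A = tf.placeholder(tf.float32, [n, d1])    # activations of the previous layer
W_var = tf.get_variable("W", [d1, d2])     # weight variable
Z = tf.matmul(A, W_var)                    # the MatMul op of interest
loss = tf.reduce_mean(tf.square(Z))        # any scalar loss works here

# Gradient w.r.t. the MatMul output Z, as expected by _MatMulGradMom.
out_grad = tf.gradients(loss, Z)[0]

# Pass the weight *tensor* consumed by the op (not the variable), as the
# docstring requires; Z.op.inputs[1] is exactly that tensor.
matmul_op = Z.op
grad_mom = _MatMulGradMom(matmul_op, matmul_op.inputs[1], out_grad, batch_size=n)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    moment_val = sess.run(grad_mom, {A: np.random.randn(n, d1).astype(np.float32)})
    print(moment_val.shape)  # (d1, d2)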