import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops


def tf_mod(x, y, name=None):
    """Differentiable mod based on numpy.

    Args:
        x: first argument
        y: second argument
    Returns:
        mod of x by y
    """
    def np_mod(x, y):
        return np.mod(x, y, dtype=np.float32)

    def modgrad(op, grad):
        x = op.inputs[0]  # the first argument (normally needed to compute the gradient, e.g. the gradient of x^2 is 2x)
        y = op.inputs[1]  # the second argument
        # The propagated gradients with respect to the first and second argument respectively:
        # d/dx mod(x, y) = 1 away from the discontinuities; the gradient w.r.t. y
        # (mathematically -floor(x/y)) is deliberately set to zero here.
        return grad * 1, grad * 0

    def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
        # Need to generate a unique name to avoid duplicate registrations:
        rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
        tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for a grad example
        g = tf.get_default_graph()
        with g.gradient_override_map({"PyFunc": rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    with ops.name_scope(name, "mod", [x, y]) as name:
        z = py_func(np_mod,
                    [x, y],
                    [tf.float32],
                    name=name,
                    grad=modgrad)  # <-- here's the call to the gradient
        return tf.reshape(z[0], tf.shape(x))
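A minimal usage sketch, assuming a TensorFlow 1.x graph/session environment (the tensor values are illustrative, not from the original):

x = tf.constant([5.3, 7.0], dtype=tf.float32)
y = tf.constant([2.0, 3.0], dtype=tf.float32)
z = tf_mod(x, y)
gx, gy = tf.gradients(z, [x, y])  # gradients flow through the py_func override
with tf.Session() as sess:
    print(sess.run(z))   # ~[1.3, 1.0]
    print(sess.run(gx))  # [1., 1.]  (d mod/dx defined as 1 in modgrad)
    print(sess.run(gy))  # [0., 0.]  (d mod/dy defined as 0 in modgrad)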