ops.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:DMNN 作者: magnux 项目源码 文件源码
def tf_mod(x, y, name=None):
    """Element-wise modulo computed with NumPy, with a hand-written gradient.

    The forward pass runs ``np.mod(x, y)`` through ``tf.py_func`` and yields
    a float32 tensor reshaped to the dynamic shape of ``x``. Because
    ``py_func`` has no built-in gradient, a custom one is registered: the
    upstream gradient is passed through unchanged to ``x`` and ``y`` receives
    a zero gradient (i.e. ``y`` is treated as a constant).

    Args:
        x: first argument (the dividend).
        y: second argument (the divisor).
        name: optional name for the enclosing name scope; defaults to "mod".

    Returns:
        The float32 result of ``x mod y``, reshaped to ``tf.shape(x)``.
    """
    # Local import keeps this block self-contained; uuid is only needed to
    # build the unique gradient-registration name below.
    import uuid

    def np_mod(a, b):
        # Forward computation executed by py_func; output forced to float32
        # to match the Tout declared in the py_func call.
        return np.mod(a, b, dtype=np.float32)

    def modgrad(op, grad):
        # Propagated gradient w.r.t. the first and second argument
        # respectively: identity for x, zero for y.
        return grad * 1, grad * 0

    def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
        # Each call needs its own gradient name to avoid duplicate
        # registrations. A UUID is used instead of np.random so that graph
        # construction neither consumes the caller's (possibly seeded) NumPy
        # RNG stream nor collides when the RNG is re-seeded between calls.
        rnd_name = 'PyFuncGrad' + uuid.uuid4().hex

        tf.RegisterGradient(rnd_name)(grad)
        g = tf.get_default_graph()
        # Ops of type "PyFunc" created inside this context use the gradient
        # registered under rnd_name.
        with g.gradient_override_map({"PyFunc": rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    with ops.name_scope(name, "mod", [x, y]) as name:
        z = py_func(np_mod,
                    [x, y],
                    [tf.float32],
                    name=name,
                    grad=modgrad)  # attach the custom gradient
        # py_func returns a list of outputs; take the single result and give
        # it the same dynamic shape as x.
        return tf.reshape(z[0], tf.shape(x))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号