layers.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def __call__(self, x, gamma=1.0):
        grad_name = "GradientReverse%d" % self.num_calls

        @ops.RegisterGradient(grad_name)
        def _gradients_reverse(op, grad):
            return [tf.neg(grad) * gamma]

        g = tf.get_default_graph()
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)

        self.num_calls += 1
        return y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号