def __call__(self, x, gamma=1.0):
    """Return ``x`` unchanged in the forward pass but with a reversed gradient.

    Builds an ``Identity`` op on ``x`` and overrides its gradient with a
    freshly registered function that negates the incoming gradient and
    multiplies it by ``gamma`` (a gradient reversal layer).

    Args:
        x: Tensor to pass through unchanged.
        gamma: Scale applied to the negated gradient (default 1.0).

    Returns:
        A tensor equal to ``x`` whose gradient with respect to ``x`` is
        ``-gamma * grad``.
    """
    # TensorFlow's gradient registry rejects duplicate names, so every call
    # registers its own numbered gradient function.
    # NOTE(review): self.num_calls is assumed to be initialized (e.g. to 0)
    # in __init__, which is not visible here; separate instances starting
    # from the same count would still collide on names.
    grad_name = "GradientReverse%d" % self.num_calls

    @ops.RegisterGradient(grad_name)
    def _gradients_reverse(op, grad):
        # tf.neg was removed in TF 1.0; tf.negative is its replacement.
        return [tf.negative(grad) * gamma]

    # The name is now taken in the registry. Advance the counter before
    # building the op, so a failure below cannot make the next call try to
    # register the same name again.
    self.num_calls += 1

    g = tf.get_default_graph()
    # Route the gradient of the Identity op created here through the
    # reversal function registered above.
    with g.gradient_override_map({"Identity": grad_name}):
        y = tf.identity(x)
    return y
评论列表
文章目录