def gradient_penalty(self):
    """Return the weighted gradient penalty on real/generated interpolates.

    One uniform random coefficient is drawn per batch element and used to
    interpolate between the real input ``x`` and the generator sample ``g``.
    The discriminator is re-applied to the interpolated tensor, the gradient
    of its output with respect to that tensor is taken, and the squared
    deviation of each sample's gradient L2 norm from 1 is averaged.

    Returns:
        The mean penalty scaled by ``float(config.gradient_penalty)``.
    """
    config = self.config
    gan = self.gan
    penalty_weight = config.gradient_penalty

    # Use the dedicated penalty target when the inputs object exposes one,
    # otherwise fall back to the regular input `x`.
    # NOTE(review): the previous code called `has_attr`, which is not a
    # Python builtin; `hasattr` is the builtin attribute check. Confirm no
    # project helper named `has_attr` with different semantics was intended.
    if hasattr(gan.inputs, 'gradient_penalty_label'):
        x = gan.inputs.gradient_penalty_label
    else:
        x = gan.inputs.x

    generator = self.generator or gan.generator
    g = generator.sample
    discriminator = self.discriminator or gan.discriminator

    # Noise has shape [batch, 1, 1, ...] so a single coefficient per batch
    # element is broadcast across every remaining axis of the sample.
    noise_shape = [1 for _ in g.get_shape()]
    noise_shape[0] = gan.batch_size()
    uniform_noise = tf.random_uniform(shape=noise_shape, minval=0., maxval=1.)
    print("[gradient penalty] applying x:", x, "g:", g, "noise:", uniform_noise)

    interpolates = x + uniform_noise * (g - x)
    reused_d = discriminator.reuse(interpolates)
    gradients = tf.gradients(reused_d, [interpolates])[0]

    # Sum over every non-batch axis to get one L2 norm per sample. Reducing
    # over axis 1 alone is only equivalent for rank-2 tensors; for
    # higher-rank inputs it would leave per-position norms instead.
    sample_axes = list(range(1, len(noise_shape)))
    gradient_norm = tf.sqrt(
        tf.reduce_sum(tf.square(gradients), axis=sample_axes))
    penalty = tf.reduce_mean(tf.square(gradient_norm - 1.))
    return float(penalty_weight) * penalty
评论列表
文章目录