base_loss.py source code

Project: HyperGAN  Author: 255BITS
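```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.random_uniform, tf.gradients)

# Method from the loss class in base_loss.py; relies on self.config, self.gan,
# self.generator and self.discriminator being set on the instance.
def gradient_penalty(self):
    config = self.config
    gan = self.gan
    gradient_penalty = config.gradient_penalty  # penalty weight (lambda)
    # Prefer a dedicated target tensor for the penalty, else use the real batch
    if hasattr(gan.inputs, 'gradient_penalty_label'):
        x = gan.inputs.gradient_penalty_label
    else:
        x = gan.inputs.x
    generator = self.generator or gan.generator
    g = generator.sample
    discriminator = self.discriminator or gan.discriminator
    # One interpolation coefficient per sample, broadcast over all other axes
    shape = [1 for t in g.get_shape()]
    shape[0] = gan.batch_size()
    uniform_noise = tf.random_uniform(shape=shape, minval=0., maxval=1.)
    print("[gradient penalty] applying x:", x, "g:", g, "noise:", uniform_noise)
    # Random points on the straight lines between real and generated samples
    interpolates = x + uniform_noise * (g - x)
    reused_d = discriminator.reuse(interpolates)
    gradients = tf.gradients(reused_d, [interpolates])[0]
    # Flatten each sample so the L2 norm covers every non-batch axis
    gradients = tf.reshape(gradients, [gan.batch_size(), -1])
    penalty = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
    # Penalize deviation of the per-sample gradient norm from 1
    penalty = tf.reduce_mean(tf.square(penalty - 1.))
    return float(gradient_penalty) * penalty
```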
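The method computes the WGAN-GP style penalty λ · mean((‖∇D(x̂)‖₂ − 1)²), where x̂ = x + ε(g − x) with one ε ~ U[0, 1] per sample. The framework-free NumPy sketch below reproduces just that arithmetic so it can be checked by hand. It is not HyperGAN code: `interpolate`, `gradient_penalty`, and `grad_fn` are hypothetical names, and `grad_fn` stands in for `tf.gradients` by returning an analytically known discriminator gradient. The default `weight=10.0` is an assumption (the value suggested in the WGAN-GP paper); HyperGAN reads the weight from `config.gradient_penalty`.

```python
import numpy as np


def interpolate(x, g, rng):
    """Random per-sample points on the segment between real x and generated g."""
    batch = x.shape[0]
    # One coefficient per sample, broadcast over all remaining axes
    eps = rng.uniform(0.0, 1.0, size=(batch,) + (1,) * (x.ndim - 1))
    return x + eps * (g - x)


def gradient_penalty(grad_fn, x, g, weight=10.0, rng=None):
    """weight * mean over the batch of (||dD/dx_hat||_2 - 1)^2.

    grad_fn(x_hat) is a hypothetical stand-in for tf.gradients: it must
    return dD/dx_hat with the same shape as x_hat.
    """
    rng = np.random.default_rng() if rng is None else rng
    x_hat = interpolate(x, g, rng)
    grads = grad_fn(x_hat)
    # Per-sample L2 norm taken over every non-batch axis
    norms = np.sqrt(np.sum(grads ** 2, axis=tuple(range(1, grads.ndim))))
    return weight * np.mean((norms - 1.0) ** 2)


# Linear critic D(x) = sum(w * x): its gradient is w at every point
w = np.full((2, 2), 1.5)  # ||w||_2 = 3
x = np.zeros((5, 2, 2))
g = np.ones((5, 2, 2))
gp = gradient_penalty(lambda t: np.broadcast_to(w, t.shape), x, g,
                      rng=np.random.default_rng(0))
print(gp)  # 10 * (3 - 1)^2
```

A linear critic makes a convenient sanity check: its gradient does not depend on where x̂ lands, so the penalty is exactly weight · (‖w‖₂ − 1)², and it vanishes when ‖w‖₂ = 1. With a real discriminator the gradient norm varies along each interpolation line, which is what the random ε samples.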