GAN_GP_Img.py source code

python

Project: GAN-general, author: weilinie
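import tensorflow as tf  # TensorFlow 1.x API (tf.random_uniform, tf.gradients)

# Method of the model class that holds self.d_net, self.batch_size, etc.;
# `discriminator` is defined elsewhere in the GAN-general project.
def cal_one_side_grad_penalty(self, real_data, fake_data):
    # One-sided WGAN Lipschitz penalty: only gradient norms above 1 are penalized.
    # One interpolation coefficient per sample, shaped [batch, 1, ..., 1] so it
    # broadcasts over every non-batch dimension (3-D or 4-D image data alike).
    eps_shape = [self.batch_size] + [1] * (real_data.get_shape().ndims - 1)
    epsilon = tf.random_uniform(shape=eps_shape, minval=0., maxval=1.)

    # Random points on straight lines between real and generated samples.
    data_diff = fake_data - real_data
    interp_data = real_data + epsilon * data_diff
    disc_interp, _ = discriminator(
        self.d_net, interp_data, self.conv_hidden_num,
        self.normalize_d
    )
    # Gradient of the critic output with respect to the interpolated inputs.
    grad_interp = tf.gradients(disc_interp, [interp_data])[0]
    print('The shape of grad_interp: {}'.format(grad_interp.get_shape().as_list()))
    # Flatten each sample's gradient and take its L2 norm (the critic's local slope).
    grad_interp_flat = tf.reshape(grad_interp, [self.batch_size, -1])
    slope = tf.norm(grad_interp_flat, axis=1)
    print('The shape of slope: {}'.format(slope.get_shape().as_list()))

    # relu(slope - 1)^2 is zero whenever slope <= 1, so only violations count.
    grad_penalty = tf.reduce_mean(tf.nn.relu(slope - 1.) ** 2)
    return grad_penalty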
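To make the arithmetic of cal_one_side_grad_penalty easy to check without TensorFlow, here is a framework-free sketch of the same one-sided penalty, mean over samples of max(0, ||grad D(x_hat)|| - 1)^2 with x_hat = real + eps * (fake - real). It is an illustration under stated assumptions, not the project's code: the critic is assumed to be any scalar function of a flat list of floats, central finite differences stand in for tf.gradients, and the helper names numerical_grad and one_sided_grad_penalty are hypothetical.

```python
import math
import random
from typing import Callable, List, Sequence

Critic = Callable[[Sequence[float]], float]


def numerical_grad(f: Critic, x: Sequence[float], h: float = 1e-5) -> List[float]:
    """Hypothetical helper: central-difference gradient of scalar f at x
    (a stand-in for tf.gradients in this sketch)."""
    grad = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        grad.append((f(xp) - f(xm)) / (2.0 * h))
    return grad


def one_sided_grad_penalty(critic: Critic,
                           real_batch: Sequence[Sequence[float]],
                           fake_batch: Sequence[Sequence[float]],
                           rng: random.Random = random.Random(0)) -> float:
    """Hypothetical helper: mean of relu(||grad critic(x_hat)||_2 - 1)^2."""
    penalties = []
    for real, fake in zip(real_batch, fake_batch):
        eps = rng.uniform(0.0, 1.0)  # one coefficient per sample
        x_hat = [r + eps * (f - r) for r, f in zip(real, fake)]
        slope = math.sqrt(sum(g * g for g in numerical_grad(critic, x_hat)))
        penalties.append(max(0.0, slope - 1.0) ** 2)  # one-sided: slope <= 1 costs 0
    return sum(penalties) / len(penalties)


if __name__ == "__main__":
    real = [[0.0, 0.0], [1.0, -1.0]]
    fake = [[2.0, 1.0], [0.5, 0.5]]
    # A linear critic w.x has gradient w everywhere, so its slope is ||w||.
    gentle = lambda x: 0.3 * x[0] + 0.4 * x[1]  # ||w|| = 0.5
    steep = lambda x: 1.8 * x[0] + 2.4 * x[1]   # ||w|| = 3.0
    print(one_sided_grad_penalty(gentle, real, fake))  # zero: slope 0.5 <= 1
    print(one_sided_grad_penalty(steep, real, fake))   # approximately (3 - 1)^2 = 4.0
```

The linear critics make the one-sided design visible: a critic with slope 0.5 pays nothing, whereas the two-sided WGAN-GP form (slope - 1)^2 would charge it 0.25 for being too flat.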