GAN_models.py source code (Python)

Project: WassersteinGAN.tensorflow  Author: shekkizh
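
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API


# Method of the GAN class; self._cross_entropy_loss and self.crop_image_size
# are defined elsewhere in that class (not shown in this excerpt).
def _gan_loss(self, logits_real, logits_fake, feature_real, feature_fake, use_features=False):
    # Discriminator: real samples should be classified as 1, generated samples as 0.
    discriminator_loss_real = self._cross_entropy_loss(logits_real, tf.ones_like(logits_real),
                                                       name="disc_real_loss")
    discriminator_loss_fake = self._cross_entropy_loss(logits_fake, tf.zeros_like(logits_fake),
                                                       name="disc_fake_loss")
    self.discriminator_loss = discriminator_loss_fake + discriminator_loss_real

    # Generator: push the discriminator to classify generated samples as 1.
    gen_loss_disc = self._cross_entropy_loss(logits_fake, tf.ones_like(logits_fake), name="gen_disc_loss")
    if use_features:
        # Optional feature-matching term: L2 distance between discriminator features
        # of real and generated batches, normalised by the image area.
        # tf.nn.l2_loss already returns a scalar (sum(t ** 2) / 2), so reduce_mean is a no-op.
        gen_loss_features = tf.reduce_mean(tf.nn.l2_loss(feature_real - feature_fake)) / (self.crop_image_size ** 2)
    else:
        gen_loss_features = 0
    self.gen_loss = gen_loss_disc + 0.1 * gen_loss_features

    # tf.scalar_summary was removed in TensorFlow 1.0; tf.summary.scalar replaces it.
    tf.summary.scalar("Discriminator_loss", self.discriminator_loss)
    tf.summary.scalar("Generator_loss", self.gen_loss)
```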
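Despite the repository name, `_gan_loss` is the standard cross-entropy GAN objective, not a Wasserstein critic loss. To make its arithmetic concrete, here is a framework-free sketch in plain Python. It assumes that the unshown helper `_cross_entropy_loss` computes the batch mean of sigmoid cross-entropy on logits (as `tf.nn.sigmoid_cross_entropy_with_logits` followed by `tf.reduce_mean` would). The function names `sigmoid_xent`, `mean_xent`, `l2_loss` and `gan_losses`, and the use of flat Python lists for logits and features, are hypothetical choices made for illustration only.

```python
import math
from typing import Optional, Sequence, Tuple


def sigmoid_xent(logit: float, label: float) -> float:
    """Numerically stable sigmoid cross-entropy for a single logit.

    Uses the formula TensorFlow documents for sigmoid_cross_entropy_with_logits:
    max(x, 0) - x * z + log(1 + exp(-|x|)).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_xent(logits: Sequence[float], label: float) -> float:
    # Assumed behaviour of _cross_entropy_loss: one target label for the whole batch,
    # averaged over all logits.
    return sum(sigmoid_xent(x, label) for x in logits) / len(logits)


def l2_loss(values: Sequence[float]) -> float:
    # Mirrors tf.nn.l2_loss: half the sum of squares, returned as a scalar.
    return sum(v * v for v in values) / 2.0


def gan_losses(logits_real: Sequence[float],
               logits_fake: Sequence[float],
               crop_image_size: int,
               feature_real: Optional[Sequence[float]] = None,
               feature_fake: Optional[Sequence[float]] = None,
               feature_weight: float = 0.1) -> Tuple[float, float]:
    """Return (discriminator_loss, generator_loss) following _gan_loss."""
    # Discriminator: real logits target 1, fake logits target 0.
    d_loss = mean_xent(logits_real, 1.0) + mean_xent(logits_fake, 0.0)
    # Generator: fake logits target 1.
    g_loss = mean_xent(logits_fake, 1.0)
    # Optional feature-matching term, normalised by the image area and weighted by 0.1.
    if feature_real is not None and feature_fake is not None:
        diff = [r - f for r, f in zip(feature_real, feature_fake)]
        g_loss += feature_weight * l2_loss(diff) / crop_image_size ** 2
    return d_loss, g_loss


# With all logits at 0 the discriminator outputs 0.5 everywhere, so each
# cross-entropy term equals ln 2: d_loss = 2 ln 2 and g_loss = ln 2.
d, g = gan_losses([0.0, 0.0], [0.0, 0.0], crop_image_size=64)
print(round(d, 4), round(g, 4))  # 1.3863 0.6931
```

The generator term reuses the fake logits with target 1, so it minimises -log D(G(z)) rather than log(1 - D(G(z))). This is the usual non-saturating choice, which keeps gradients useful early in training when the discriminator confidently rejects generated samples.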