Gan.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:ICGan-tensorflow 作者: zhangqianhui 项目源码 文件源码
def build_model1(self):
    """Construct the conditional GAN graph and its training bookkeeping.

    Builds, in order:
      * ``self.fake_images`` -- generator output from ``self.z`` conditioned
        on ``self.y`` (generator parameters ``self.weights1``/``self.biases1``);
      * ``self.D_pro`` / ``self.G_pro`` -- discriminator outputs for the real
        batch ``self.images`` and for ``self.fake_images`` (discriminator
        parameters ``self.weights2``/``self.biases2``);
      * ``self.G_fake_loss`` (generator) and ``self.loss`` (discriminator);
      * ``self.d_vars`` / ``self.g_vars`` -- trainable variables whose names
        contain ``'dis'`` / ``'gen'`` respectively;
      * ``self.saver`` -- a ``tf.train.Saver`` over ``self.g_vars`` only;
      * one scalar summary per ``(name, tensor)`` pair in ``self.log_vars``.
    """
    # Generator forward pass, conditioned on the label input y.
    self.fake_images = self.generate(
        self.z, self.y, weights=self.weights1, biases=self.biases1)

    # Discriminator on real data, then on generated data.
    # NOTE(review): the trailing bool looks like a variable-reuse switch
    # (False on the first call, True on the second) -- confirm in discriminate().
    self.D_pro = self.discriminate(
        self.images, self.y, self.weights2, self.biases2, False)
    self.G_pro = self.discriminate(
        self.fake_images, self.y, self.weights2, self.biases2, True)

    # Generator objective: -mean(log D(G(z))). TINY keeps log() away from 0.
    # Assumes discriminate() yields probabilities in (0, 1) -- TODO confirm.
    self.G_fake_loss = -tf.reduce_mean(tf.log(self.G_pro + TINY))

    # Discriminator objective: -mean(log(1 - D(G(z))) + log D(x)).
    fake_term = tf.log(1. - self.G_pro + TINY)
    real_term = tf.log(self.D_pro + TINY)
    self.loss = -tf.reduce_mean(fake_term + real_term)

    for entry in (("generator_loss", self.G_fake_loss),
                  ("discriminator_loss", self.loss)):
        self.log_vars.append(entry)

    # Partition trainable variables by substring match on their names.
    trainable = tf.trainable_variables()
    self.d_vars = [v for v in trainable if 'dis' in v.name]
    self.g_vars = [v for v in trainable if 'gen' in v.name]

    # Checkpointing here covers only the generator's variables.
    self.saver = tf.train.Saver(self.g_vars)

    # Registers a summary for every entry in log_vars, including any that
    # were appended before this method ran.
    for summary_name, tensor in self.log_vars:
        tf.summary.scalar(summary_name, tensor)

    #Training the Encode_z
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号