dcgan.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:dcgan-tfslim 作者: mqtlam 项目源码 文件源码
def __init__(self, sess, FLAGS):
        """Build the DCGAN graph: input placeholders, G/D networks, losses, saver.

        Args:
            sess: TensorFlow session, stored as ``self.sess``.
            FLAGS: flags object, stored as ``self.f``; this method reads its
                ``output_size``, ``c_dim`` and ``z_dim`` fields and also passes
                it to ``Generator`` and ``Discriminator``.

        Attributes set:
            real_images: float32 placeholder of shape
                ``[None, output_size, output_size, c_dim]``.
            z: float32 noise placeholder of shape ``[None, z_dim]``.
            G: generator output for ``z``.
            D_real, D_real_logits: discriminator output / logits on real images.
            D_fake, D_fake_logits: discriminator output / logits on ``G``
                (built with ``reuse=True`` so both calls share weights).
            d_loss_real, d_loss_fake, d_loss, g_loss: scalar losses.
            d_vars, g_vars: trainable variables of the discriminator and the
                generator, used to restrict each optimizer's updates.
            saver: ``tf.train.Saver()`` with its default variable list.
        """
        # initialize variables
        self.sess = sess
        self.f = FLAGS

        # inputs: real (training) images
        images_shape = [self.f.output_size, self.f.output_size, self.f.c_dim]
        self.real_images = tf.placeholder(tf.float32,
            [None] + images_shape, name="real_images")

        # inputs: z (noise)
        self.z = tf.placeholder(tf.float32, [None, self.f.z_dim], name='z')

        # initialize models
        generator = Generator(FLAGS)
        discriminator = Discriminator(FLAGS)

        # generator network
        self.G = generator(self.z)
        # discriminator network for real images
        self.D_real, self.D_real_logits = discriminator(self.real_images)
        # discriminator network for fake images; reuse=True so the same
        # discriminator weights score both real and generated batches
        self.D_fake, self.D_fake_logits = discriminator(self.G, reuse=True)

        # losses
        # D is trained to label real images 1 ...
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_real_logits,
                labels=tf.ones_like(self.D_real))
            )
        # ... and generated images 0
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_fake_logits,
                labels=tf.zeros_like(self.D_fake))
            )
        self.d_loss = self.d_loss_real + self.d_loss_fake
        # G is trained to make D label its samples 1
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_fake_logits,
                labels=tf.ones_like(self.D_fake))
            )

        # create summaries
        self.__create_summaries()

        # organize variables by whole variable-scope path components. A raw
        # substring test such as `"d/" in var.name` also matches generator
        # names like "g/fully_connected/weights:0" (the "connected/" segment
        # contains "d/"). That would put generator weights into the
        # discriminator's update list, so match scope names exactly instead.
        # NOTE(review): assumes Discriminator/Generator create their variables
        # under variable scopes named "d" and "g" — confirm in the models.
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars
                       if "d" in var.name.split("/")[:-1]]
        self.g_vars = [var for var in t_vars
                       if "g" in var.name.split("/")[:-1]]

        # saver
        self.saver = tf.train.Saver()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号