def __init__(self, sess, FLAGS):
    """Build the GAN graph: input placeholders, networks, losses and saver.

    Args:
        sess: TensorFlow session, stored as ``self.sess``.
        FLAGS: flags object, stored as ``self.f``. This method reads
            ``output_size``, ``c_dim`` and ``z_dim`` from it and passes it
            to ``Generator`` and ``Discriminator``.
    """
    self.sess = sess
    self.f = FLAGS

    # Real (training) images; the leading (batch) dimension is left open.
    images_shape = [self.f.output_size, self.f.output_size, self.f.c_dim]
    self.real_images = tf.placeholder(
        tf.float32, [None] + images_shape, name="real_images")
    # Noise input z: [batch, z_dim], batch dimension left open.
    self.z = tf.placeholder(tf.float32, [None, self.f.z_dim], name='z')

    generator = Generator(FLAGS)
    discriminator = Discriminator(FLAGS)

    # Generator maps noise to fake images.
    self.G = generator(self.z)
    # The discriminator is applied to real and to generated images. The
    # second call passes reuse=True, presumably so both passes share one set
    # of weights (depends on how Discriminator honours the flag).
    self.D_real, self.D_real_logits = discriminator(self.real_images)
    self.D_fake, self.D_fake_logits = discriminator(self.G, reuse=True)

    # Discriminator loss: real images labelled 1, generated images labelled 0.
    self.d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=self.D_real_logits,
            labels=tf.ones_like(self.D_real)))
    self.d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=self.D_fake_logits,
            labels=tf.zeros_like(self.D_fake)))
    self.d_loss = self.d_loss_real + self.d_loss_fake

    # Generator loss: generated images labelled 1. This is the non-saturating
    # form (maximize log D(G(z)) instead of minimizing log(1 - D(G(z)))).
    self.g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=self.D_fake_logits,
            labels=tf.ones_like(self.D_fake)))

    self.__create_summaries()

    # Split the trainable variables by network scope. The scope must be a
    # whole "/"-separated component of the variable's name. A plain substring
    # test such as '"d/" in name' also matches generator variables like
    # "g/conv2d/kernel:0" (since "conv2d/" ends in "d/"). Those variables
    # would then be updated by both the D and G optimizers. The final
    # component (the variable's own name, e.g. "kernel:0") is excluded.
    def in_scope(var, scope):
        return scope in var.name.split("/")[:-1]

    t_vars = tf.trainable_variables()
    self.d_vars = [var for var in t_vars if in_scope(var, "d")]
    self.g_vars = [var for var in t_vars if in_scope(var, "g")]

    # Checkpoint saver; no var_list, so TF's default set of saveable
    # variables is used.
    self.saver = tf.train.Saver()
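# NOTE: the two labels below were scraped from the source web page and are not
# code: "评论列表" ("Comment list") and "文章目录" ("Table of contents").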