def create(self):
    """Build the encoder, the standard GAN graph, the cycle loss and the trainer.

    Steps, in order:

    1. Build ``self.encoder`` from a copy of ``config.discriminator`` whose
       ``'class'`` entry is overridden with ``PyramidDiscriminator``, and
       create it on ``self.inputs.x``.
    2. Create ``self.trainer`` from ``config.trainer``, then run
       ``StandardGAN.create(self)``.
    3. Scale ``self.loss.sample[1]`` by ``config.g_lambda`` (default 1) and
       add a cycle term, the mean absolute difference between
       ``self.inputs.x`` and ``self.generator.sample``, weighted by
       ``config.cycloss_lambda`` (default 10).
    4. Call ``self.trainer.create()`` and run the global variable
       initializer in ``self.session``.

    ``g_lambda`` and ``cycloss_lambda`` fall back to their defaults only
    when the config value is ``None`` (unset). An explicit ``0`` is
    honoured, so ``cycloss_lambda: 0`` disables the cycle term.
    """
    config = self.config

    # Copy the discriminator config so overriding 'class' below does not
    # mutate the shared config.discriminator mapping.
    encoder_config = dict(config.discriminator)
    encoder_config['class'] = self.ops.lookup(
        "class:hypergan.discriminators.pyramid_discriminator.PyramidDiscriminator")
    self.encoder = self.create_component(encoder_config)
    self.encoder.ops.describe("encoder")
    self.encoder.create(self.inputs.x)
    # NOTE(review): empty placeholder z on the encoder; presumably other code
    # expects an encoder.z attribute -- confirm it is still needed.
    self.encoder.z = tf.zeros(0)

    self.trainer = self.create_component(config.trainer)
    StandardGAN.create(self)

    # Cycle loss: mean absolute (L1) difference between the real input and
    # the generator's sample.
    cycloss = tf.reduce_mean(tf.abs(self.inputs.x - self.generator.sample))

    # Use `is None` instead of `value or default`: `or` would silently
    # replace an explicit 0 with the default.
    cycloss_lambda = config.cycloss_lambda
    if cycloss_lambda is None:
        cycloss_lambda = 10
    g_lambda = config.g_lambda
    if g_lambda is None:
        g_lambda = 1

    # Presumably loss.sample[1] is the generator loss (it is scaled by
    # g_lambda) -- confirm against the loss component.
    self.loss.sample[1] *= g_lambda
    self.loss.sample[1] += cycloss * cycloss_lambda

    # The trainer is created only after the loss edits above, presumably so
    # its update ops see the modified loss -- confirm.
    self.trainer.create()
    self.session.run(tf.global_variables_initializer())
评论列表
文章目录