def prepare_trainer(self, generator_loss, discriminator_loss):
    """Build the Adam train ops for the generator and discriminator.

    Helper for ``init_opt``: partitions the graph's trainable variables by
    name prefix (``g_*`` -> generator, ``d_*`` -> discriminator), creates one
    Adam optimizer per network, and stores the resulting PrettyTensor train
    ops on ``self``.

    Args:
        generator_loss: Scalar loss tensor minimized w.r.t. ``g_*`` variables.
        discriminator_loss: Scalar loss tensor minimized w.r.t. ``d_*``
            variables.

    Side effects:
        Sets ``self.generator_trainer`` and ``self.discriminator_trainer``,
        and appends both learning rates to ``self.log_vars`` for logging.
    """
    all_vars = tf.trainable_variables()
    # Variables belong to a network purely by naming convention: the model
    # code must create them under names starting with 'g_' or 'd_'.
    g_vars = [var for var in all_vars if var.name.startswith('g_')]
    d_vars = [var for var in all_vars if var.name.startswith('d_')]

    # beta1=0.5 is the reduced first-moment decay commonly used for GAN
    # training (DCGAN); each network gets its own optimizer and var_list so
    # the two updates never touch the other network's weights.
    generator_opt = tf.train.AdamOptimizer(self.generator_lr, beta1=0.5)
    self.generator_trainer = pt.apply_optimizer(
        generator_opt, losses=[generator_loss], var_list=g_vars)

    discriminator_opt = tf.train.AdamOptimizer(self.discriminator_lr, beta1=0.5)
    self.discriminator_trainer = pt.apply_optimizer(
        discriminator_opt, losses=[discriminator_loss], var_list=d_vars)

    self.log_vars.append(("g_learning_rate", self.generator_lr))
    self.log_vars.append(("d_learning_rate", self.discriminator_lr))
# NOTE(review): removed stray web-page scrape text ("评论列表" / "文章目录",
# i.e. "comment list" / "article table of contents") that was pasted in here
# and made the module unparseable — it was never part of the program.