mnist_gan.py source code

python

Project: deep-learning, author: ljanyst
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.variable_scope, tf.train)


def get_optimizers(self, learning_rate=0.002, smooth=0.1):
    """Build the GAN losses and one Adam training op per network.

    Method of the model class; expects self.dsc_real_logits and
    self.dsc_fake_logits (discriminator outputs on real and generated
    images) to exist already.
    """
    #---------------------------------------------------------------------------
    # Define loss functions
    #---------------------------------------------------------------------------
    with tf.variable_scope('losses'):
        # Discriminator on real images: one-sided label smoothing, so the
        # target is (1 - smooth) instead of 1
        dsc_real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.dsc_real_logits,
                labels=tf.ones_like(self.dsc_real_logits) * (1 - smooth)))

        # Discriminator on generated images: target is 0
        dsc_fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.dsc_fake_logits,
                labels=tf.zeros_like(self.dsc_fake_logits)))

        dsc_loss = (dsc_real_loss + dsc_fake_loss) / 2

        # Generator: wants the discriminator to classify its samples as real
        gen_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.dsc_fake_logits,
                labels=tf.ones_like(self.dsc_fake_logits)))

    #---------------------------------------------------------------------------
    # Optimizers: each network updates only its own variables
    #---------------------------------------------------------------------------
    trainable_vars = tf.trainable_variables()
    gen_vars = [var for var in trainable_vars
                if var.name.startswith('generator')]
    dsc_vars = [var for var in trainable_vars
                if var.name.startswith('discriminator')]

    with tf.variable_scope('optimizers'):
        with tf.variable_scope('discriminator_optimizer'):
            dsc_train_opt = tf.train.AdamOptimizer(learning_rate) \
                .minimize(dsc_loss, var_list=dsc_vars)
        with tf.variable_scope('generator_optimizer'):
            gen_train_opt = tf.train.AdamOptimizer(learning_rate) \
                .minimize(gen_loss, var_list=gen_vars)

    return dsc_train_opt, gen_train_opt, dsc_loss, gen_loss
```
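To make the loss block easier to check by hand, here is a plain-Python mirror of the three losses built in `get_optimizers`. It is a sketch, not the project's code: `sigmoid_xent`, `mean` and `gan_losses` are hypothetical helpers, and lists of floats stand in for the logit tensors. `sigmoid_xent` uses the numerically stable formula that the TensorFlow documentation gives for `tf.nn.sigmoid_cross_entropy_with_logits`, and each loss is averaged the way `tf.reduce_mean` averages it.

```python
import math


def sigmoid_xent(logit, label):
    """Stable sigmoid cross-entropy for one logit x and target z:
    max(x, 0) - x * z + log(1 + exp(-|x|))
    """
    x, z = logit, label
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def mean(values):
    # Plays the role of tf.reduce_mean over a batch of logits
    return sum(values) / len(values)


def gan_losses(real_logits, fake_logits, smooth=0.1):
    """Return (dsc_loss, gen_loss) computed like get_optimizers does."""
    # Real samples: one-sided label smoothing, target is 1 - smooth
    dsc_real_loss = mean([sigmoid_xent(x, 1.0 - smooth) for x in real_logits])
    # Generated samples: the discriminator's target is 0
    dsc_fake_loss = mean([sigmoid_xent(x, 0.0) for x in fake_logits])
    dsc_loss = (dsc_real_loss + dsc_fake_loss) / 2
    # Generator: target 1 on generated samples
    gen_loss = mean([sigmoid_xent(x, 1.0) for x in fake_logits])
    return dsc_loss, gen_loss


if __name__ == "__main__":
    real = [2.0, 3.0, 1.5]     # made-up discriminator logits on real images
    fake = [-1.0, -2.5, 0.5]   # made-up discriminator logits on generated images
    dsc_loss, gen_loss = gan_losses(real, fake)
    print(f"dsc_loss={dsc_loss:.4f} gen_loss={gen_loss:.4f}")
```

One consequence of the smoothing is visible directly in `sigmoid_xent`: its derivative in the logit is `sigmoid(x) - z`, so with `smooth=0.1` the real-image term is minimized at `sigmoid(x) = 0.9` (logit `log 9`) rather than pushed toward an infinitely large logit.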