gan.py source file

python

Project: vae-npvc  Author: JeremyCCHsu
def _optimize(self):
        '''
        NOTE: The author said that there was no need for 100 d_iter per 100 iters. 
              https://github.com/igul222/improved_wgan_training/issues/3
        '''
        # Step counter, incremented by the optimizer; excluded from the trainable set.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        lr = self.arch['training']['lr']
        b1 = self.arch['training']['beta1']
        b2 = self.arch['training']['beta2']

        optimizer = tf.train.AdamOptimizer(lr, b1, b2)

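        # Update every trainable variable in the graph.
        trainables = tf.trainable_variables()
        g_vars = trainables
        # Alternative: restrict the update to generator and embedding variables only.
        # g_vars = [v for v in trainables if 'Generator' in v.name or 'y_emb' in v.name]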

        with tf.name_scope('Update'):        
            opt_g = optimizer.minimize(self.loss['G'], var_list=g_vars, global_step=global_step)
        return {
            'g': opt_g,
            'global_step': global_step
        }
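
The commented-out filter in `_optimize` would restrict the update to variables whose names contain 'Generator' or 'y_emb', rather than every trainable variable. Below is a minimal, framework-free sketch of that name-based selection. `Var` is a hypothetical stand-in for `tf.Variable` (only the `.name` attribute is used), `select_vars` is an illustrative helper that does not exist in the project, and the variable names are made up.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable: only the `.name` attribute matters here.
Var = namedtuple('Var', ['name'])


def select_vars(variables, keywords):
    """Return the variables whose name contains at least one keyword."""
    return [v for v in variables if any(k in v.name for k in keywords)]


# Made-up variable names, for illustration only.
trainables = [
    Var('Generator/dense/kernel:0'),
    Var('Discriminator/conv/kernel:0'),
    Var('y_emb:0'),
    Var('Encoder/conv/kernel:0'),
]

# Same condition as the commented-out line: 'Generator' in name or 'y_emb' in name.
g_vars = select_vars(trainables, ['Generator', 'y_emb'])
print([v.name for v in g_vars])
```

A list selected this way could be passed as `var_list` to `optimizer.minimize`, so that only the matching variables are updated.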