WGAN_GP_Char.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:GAN-general 作者: weilinie 项目源码 文件源码
def train(self):
        """Run the WGAN-GP training loop for ``self.max_step`` steps.

        Each step does the following:

        * One generator update: ``self.g_optim`` is run together with
          ``self.summary_op``, and the summary is written and flushed.
        * ``self.critic_iters`` critic updates, each running ``self.d_optim``.

        Every ``sess.run`` call is fed a fresh batch drawn from
        ``inf_train_gen(self.lines, self.batch_size, self.charmap)``.

        On every 100th step, ``self.g_loss``, ``self.d_loss`` and ``self.slope``
        are evaluated on one more fresh batch and printed. Then
        ``self.generate_samples`` is called with a noise matrix that is
        sampled once before the loop, so every sample dump uses the same
        latent inputs.
        """
        # Drawn once and reused for every sample dump (10x batch size rows).
        z_fixed = np.random.normal(size=[self.batch_size * 10, self.z_dim])
        gen = inf_train_gen(self.lines, self.batch_size, self.charmap)

        for step in trange(self.max_step):
            # Train generator.
            # Use the builtin next(): gen.next() only exists on Python 2
            # generators and raises AttributeError on Python 3.
            _data = next(gen)
            summary_str, _ = self.sess.run(
                [self.summary_op, self.g_optim],
                feed_dict={self.real_data: _data})
            self.summary_writer.add_summary(summary_str, global_step=step)
            self.summary_writer.flush()

            # Train critic, each iteration on its own fresh batch.
            for _ in range(self.critic_iters):
                _data = next(gen)
                self.sess.run(self.d_optim, feed_dict={self.real_data: _data})

            # Periodic logging + sample dump (steps 99, 199, ...).
            if step % 100 == 99:
                _data = next(gen)
                g_loss, d_loss, slope = self.sess.run(
                    [self.g_loss, self.d_loss, self.slope],
                    feed_dict={self.real_data: _data})
                print("[{}/{}] Loss_D: {:.6f} Loss_G: {:.6f} Slope: {:.6f}".
                      format(step + 1, self.max_step, d_loss, g_loss, slope))
                self.generate_samples(z_fixed, idx=step + 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号