improved_WGAN.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Conditional-GAN 作者: m516825 项目源码 文件源码
def train(self):
        """Run the WGAN training loop for ``self.FLAGS.iter`` outer iterations.

        Each outer iteration performs ``self.d_epoch`` discriminator (critic)
        updates. Each critic update uses a fresh batch from
        ``self.data.next_data_batch`` and fresh noise from
        ``self.data.next_noise_batch``. One generator update follows. It
        reuses the image/tag batch drawn by the last critic update, together
        with newly sampled noise.

        After the generator update, the global step is read back. Depending on
        ``FLAGS.display_every`` / ``checkpoint_every`` / ``dump_every``, the
        loop then:

        - prints the averaged critic loss and the generator loss,
        - saves a checkpoint via ``self.saver``,
        - calls ``self.eval``.

        Raises:
            ValueError: if ``self.d_epoch`` is less than 1. At least one
                critic step is required, because the generator step reuses
                the batch that step draws.
        """
        # The generator update below reuses the batch drawn by the last critic
        # step; with zero critic steps that batch would never be bound and the
        # loop would die with an UnboundLocalError instead of a clear message.
        if self.d_epoch < 1:
            raise ValueError(
                "d_epoch must be >= 1, got {}".format(self.d_epoch))

        print("Start training WGAN...\n")

        for _ in range(self.FLAGS.iter):

            # Mean critic loss over this iteration's d_epoch updates.
            d_cost = 0

            for _ in range(self.d_epoch):

                img, tags, _, w_img, w_tags = self.data.next_data_batch(self.FLAGS.batch_size)
                z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)

                feed_dict = {
                    self.seq: tags,
                    self.img: img,
                    self.z: z,
                    self.w_seq: w_tags,
                    self.w_img: w_img,
                }

                _, loss = self.sess.run([self.d_updates, self.d_loss], feed_dict=feed_dict)

                d_cost += loss / self.d_epoch

            # Generator step: fresh noise, same conditioning batch as the last
            # critic update. All placeholders are fed, presumably because
            # g_loss shares graph inputs with the critic -- TODO confirm.
            z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)
            feed_dict = {
                self.img: img,
                self.w_seq: w_tags,
                self.w_img: w_img,
                self.seq: tags,
                self.z: z,
            }

            _, g_cost = self.sess.run([self.g_updates, self.g_loss], feed_dict=feed_dict)

            # Read the step in a separate call, after the update has finished.
            # A global_step fetched in the same sess.run as g_updates may not
            # reflect the increment (presumably g_updates bumps global_step).
            current_step = tf.train.global_step(self.sess, self.global_step)

            if current_step % self.FLAGS.display_every == 0:
                print("Epoch {}, Current_step {}".format(self.data.epoch, current_step))
                print("Discriminator loss :{}".format(d_cost))
                print("Generator loss     :{}".format(g_cost))
                print("---------------------------------")

            if current_step % self.FLAGS.checkpoint_every == 0:
                path = self.saver.save(self.sess, self.checkpoint_prefix, global_step=current_step)
                print("\nSaved model checkpoint to {}\n".format(path))

            if current_step % self.FLAGS.dump_every == 0:
                self.eval(current_step)
                print("Dump test image")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号