infogan.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Keras-GAN 作者: eriklindernoren 项目源码 文件源码
def train(self, epochs, batch_size=128, save_interval=50):
        """Train the InfoGAN on the MNIST training split.

        Despite its name, ``epochs`` counts training *iterations*: each pass
        through the loop draws a single batch rather than sweeping the whole
        dataset.

        Each iteration:
          1. trains ``self.discriminator`` on ``batch_size // 2`` generated
             images and ``batch_size // 2`` real images, then averages the
             two returned loss/metric lists;
          2. trains the generator through ``self.combined`` on a fresh batch
             of ``batch_size`` generator inputs targeted as "valid";
          3. prints a progress line and, whenever ``epoch % save_interval``
             is 0, calls ``self.save_imgs(epoch)``.

        Args:
            epochs: Number of training iterations to run.
            batch_size: Samples per generator step. Each discriminator step
                uses half of it for fake images and half for real images.
                Must be at least 2.
            save_interval: Image-saving period, in iterations. Must be at
                least 1.

        Raises:
            ValueError: If ``batch_size < 2`` or ``save_interval < 1``.
        """
        # Validate arguments before the (possibly slow) dataset load.
        # batch_size < 2 would make half_batch 0 and train on empty batches.
        if batch_size < 2:
            raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))
        # save_interval is used as a modulus below; 0 would divide by zero.
        if save_interval < 1:
            raise ValueError("save_interval must be at least 1, got %r" % (save_interval,))

        # Load the dataset; the test split is discarded.
        (X_train, y_train), (_, _) = mnist.load_data()

        # Map pixel values 0..255 linearly onto -1..1.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        # Add a trailing channel axis (presumably (n, 28, 28) -> (n, 28, 28, 1); confirm).
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        half_batch = batch_size // 2

        # These targets are identical every iteration, so build them once.
        fake = np.zeros((half_batch, 1))
        valid_half = np.ones((half_batch, 1))
        valid_full = np.ones((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Fake half batch. The discriminator is trained to output "fake"
            # and to recover the categorical and continuous codes that were
            # fed to the generator.
            sampled_noise, sampled_labels, sampled_cont = self.sample_generator_input(half_batch)
            gen_input = np.concatenate((sampled_noise, sampled_labels, sampled_cont), axis=1)
            gen_imgs = self.generator.predict(gen_input)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, [fake, sampled_labels, sampled_cont])

            # Real half batch, sampled with replacement (randint may repeat indices).
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]
            labels = to_categorical(y_train[idx], num_classes=self.num_classes)
            # NOTE(review): real images have no ground-truth continuous code, yet
            # the continuous-code output is trained here to reproduce the codes
            # sampled for the *fake* batch above, which are unrelated to `imgs`.
            # Removing this needs a model-side change (e.g. a code head trained
            # only on generated samples) outside this method; zero sample
            # weights can yield NaN in Keras 2.x. Confirm and fix there.
            d_loss_real = self.discriminator.train_on_batch(imgs, [valid_half, labels, sampled_cont])

            # Element-wise average of the two loss/metric lists.
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Fresh full batch targeted as "valid", trained through
            # self.combined (presumably generator -> frozen discriminator; confirm).
            sampled_noise, sampled_labels, sampled_cont = self.sample_generator_input(batch_size)
            gen_input = np.concatenate((sampled_noise, sampled_labels, sampled_cont), axis=1)
            g_loss = self.combined.train_on_batch(gen_input, [valid_full, sampled_labels, sampled_cont])

            # Plot the progress. Indices 4 and 5 assume the discriminator was
            # compiled with metrics=['accuracy'] on its three outputs, making
            # them the validity and label accuracies. TODO confirm against compile().
            print("%d [D loss: %.2f, acc.: %.2f%%, label_acc: %.2f%%] [G loss: %.2f]" % (epoch, d_loss[0], 100*d_loss[4], 100*d_loss[5], g_loss[0]))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号