def train(self, epochs, batch_size=128, save_interval=50):
    """Run the adversarial training loop on MNIST.

    Every iteration performs one discriminator update followed by one
    generator update; note that an "epoch" here is a single batch step,
    not a full pass over the dataset.

    Labels are two-column rows: ``[1, 0]`` marks a sample as valid (real)
    and ``[0, 1]`` marks it as fake (generated).

    Args:
        epochs: Number of training iterations to run.
        batch_size: Number of noise samples used for each generator update.
            The discriminator sees ``batch_size // 2`` real images and the
            same number of generated images per iteration.
        save_interval: Call ``self.save_imgs(epoch)`` whenever ``epoch`` is
            a multiple of this value. A falsy value (``0`` or ``None``)
            disables snapshots instead of raising ``ZeroDivisionError``.

    Returns:
        None. Progress is printed to stdout once per iteration.
    """
    # NOTE(review): ``fetch_mldata`` has been removed from current
    # scikit-learn releases (``fetch_openml`` is the documented successor).
    # The loader is left untouched here because its import lives outside
    # this block -- confirm against the pinned scikit-learn version.
    mnist = fetch_mldata('MNIST original')
    X = mnist.data.reshape((-1,) + self.img_shape)

    # Rescale -1 to 1 (presumes 8-bit pixel intensities in [0, 255]).
    X = (X.astype(np.float32) - 127.5) / 127.5

    half_batch = batch_size // 2

    # The label arrays depend only on the batch sizes, so build them once
    # instead of re-allocating them on every iteration.
    d_valid = np.concatenate(
        (np.ones((half_batch, 1)), np.zeros((half_batch, 1))), axis=1)
    d_fake = np.concatenate(
        (np.zeros((half_batch, 1)), np.ones((half_batch, 1))), axis=1)
    # The generator wants the discriminator to label its samples as valid.
    g_valid = np.concatenate(
        (np.ones((batch_size, 1)), np.zeros((batch_size, 1))), axis=1)

    for epoch in range(epochs):
        # ---------------------
        #  Train Discriminator
        # ---------------------
        # NOTE(review): ``set_trainable`` is not a stock Keras ``Model``
        # method; presumably a project-defined helper -- verify.
        self.discriminator.set_trainable(True)

        # Select a random half batch of real images (sampled with
        # replacement, so duplicates within a batch are possible).
        idx = np.random.randint(0, X.shape[0], half_batch)
        imgs = X[idx]

        # Generate a half batch of images.
        # NOTE(review): the noise width 100 is hard-coded and presumably
        # must match the generator's input size -- confirm.
        noise = np.random.normal(0, 1, (half_batch, 100))
        gen_imgs = self.generator.predict(noise)

        # Train the discriminator on real and generated samples separately
        # and report the mean of the two losses / accuracies.
        d_loss_real, d_acc_real = self.discriminator.train_on_batch(
            imgs, d_valid)
        d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(
            gen_imgs, d_fake)
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_acc = 0.5 * (d_acc_real + d_acc_fake)

        # ---------------------
        #  Train Generator
        # ---------------------
        # We only want to train the generator for the combined model.
        self.discriminator.set_trainable(False)

        # Sample fresh noise and use it as generator input.
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss, g_acc = self.combined.train_on_batch(noise, g_valid)

        # Display the progress.
        print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, acc: %.2f%%]" % (epoch, d_loss, 100*d_acc, g_loss, 100*g_acc))

        # If at save interval => save generated image samples.
        if save_interval and epoch % save_interval == 0:
            self.save_imgs(epoch)
评论列表
文章目录