def train(self, epochs, batch_size=128, save_interval=50):
    """Run the adversarial training loop for `epochs` iterations.

    Each iteration trains the discriminator on a half batch of generated
    images and a half batch of real MNIST images, then trains the generator
    (via the combined model) on a full batch.

    Args:
        epochs: Number of training iterations (one batch each, not passes
            over the dataset).
        batch_size: Full batch size used for the generator step; the
            discriminator is trained on two half batches of this size.
        save_interval: Save generated image samples every this many epochs.
    """
    # Load the dataset (only the training split is used)
    (X_train, y_train), (_, _) = mnist.load_data()

    # Rescale pixel values from [0, 255] to [-1, 1]
    # (presumably matching a tanh generator output — confirm in the model)
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)
    y_train = y_train.reshape(-1, 1)

    half_batch = batch_size // 2

    for epoch in range(epochs):
        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Train discriminator on generator output:
        # sample latent input (noise + categorical label + continuous code)
        sampled_noise, sampled_labels, sampled_cont = self.sample_generator_input(half_batch)
        gen_input = np.concatenate((sampled_noise, sampled_labels, sampled_cont), axis=1)

        # Generate a half batch of new images
        gen_imgs = self.generator.predict(gen_input)

        fake = np.zeros((half_batch, 1))
        d_loss_fake = self.discriminator.train_on_batch(
            gen_imgs, [fake, sampled_labels, sampled_cont])

        # Train discriminator on real data:
        # select a random half batch of real images
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        imgs = X_train[idx]
        labels = to_categorical(y_train[idx], num_classes=self.num_classes)
        valid = np.ones((half_batch, 1))
        # NOTE(review): the continuous-code output is trained against
        # `sampled_cont`, which was drawn for the FAKE half batch — real
        # images have no ground-truth codes. This matches the original
        # code but looks dubious; confirm before changing.
        d_loss_real = self.discriminator.train_on_batch(
            imgs, [valid, labels, sampled_cont])

        # Average the real/fake losses element-wise
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------
        valid = np.ones((batch_size, 1))
        sampled_noise, sampled_labels, sampled_cont = self.sample_generator_input(batch_size)
        gen_input = np.concatenate((sampled_noise, sampled_labels, sampled_cont), axis=1)

        # Train the generator to have its output classified as valid
        g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels, sampled_cont])

        # Plot the progress (loss/accuracy indices depend on the
        # discriminator's compiled output order — see model definition)
        print("%d [D loss: %.2f, acc.: %.2f%%, label_acc: %.2f%%] [G loss: %.2f]"
              % (epoch, d_loss[0], 100 * d_loss[4], 100 * d_loss[5], g_loss[0]))

        # If at save interval => save generated image samples
        if epoch % save_interval == 0:
            self.save_imgs(epoch)
# (removed scraped-page boilerplate: "评论列表" / "文章目录" — comment list / article table of contents)