def train(BATCH_SIZE):
    """Train the generator and discriminator on paired images for 100 epochs.

    Loads the ``'train'`` split through ``get_data``, rescales both arrays
    with ``(x - 127.5) / 127.5``, builds and compiles the generator, the
    discriminator and the stacked generator+discriminator model, then performs
    one discriminator update followed by one generator update per batch.

    Every 20th batch a grid of generated images is written to
    ``<epoch>_<index>.png`` in the working directory and the generator and
    discriminator weights are saved to the files ``generator`` and
    ``discriminator`` (overwriting earlier saves).

    Args:
        BATCH_SIZE: Number of (input, target) pairs per training step. Only
            full batches are used; trailing samples that do not fill a batch
            are skipped.
    """
    (X_train, Y_train) = get_data('train')
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5

    discriminator = discriminator_model()
    generator = generator_model()
    generator.summary()
    discriminator_on_generator = generator_containing_discriminator(
        generator, discriminator)

    generator.compile(loss='mse', optimizer="rmsprop")
    discriminator_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")

    # Per-sample shape of the label arrays handed to the discriminator.
    # NOTE(review): presumably this must equal the discriminator's per-sample
    # output shape -- confirm against discriminator_model().
    patch_shape = (1, 64, 64)

    # The label arrays never change, so build them once. Their leading
    # dimension follows BATCH_SIZE (previously hard-coded to 20 and 10, which
    # only matched the data when BATCH_SIZE == 10).
    # NOTE(review): the discriminator labels are all zeros for real AND fake
    # pairs; an earlier comment suggested [1] * BATCH_SIZE + [0] * BATCH_SIZE.
    # Presumably discriminator_loss does not rely on y_true -- confirm.
    d_labels = np.zeros((2 * BATCH_SIZE,) + patch_shape)
    g_labels = np.ones((BATCH_SIZE,) + patch_shape)

    num_batches = int(X_train.shape[0] / BATCH_SIZE)

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            batch = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
            input_batch = X_train[batch]
            image_batch = Y_train[batch]
            generated_images = generator.predict(input_batch)

            if index % 20 == 0:
                # Periodic preview: undo the input scaling and write a grid
                # of the current generator outputs to disk.
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)

            # The discriminator sees the input concatenated with either the
            # real target or the generated image along axis 1; real pairs
            # come first, fake pairs second.
            real_pairs = np.concatenate((input_batch, image_batch), axis=1)
            fake_pairs = np.concatenate((input_batch, generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            d_loss = discriminator.train_on_batch(X, d_labels)
            print("batch %d d_loss : %f" % (index, d_loss))

            # NOTE(review): `trainable` is toggled after the models were
            # compiled; confirm that the installed Keras version honours the
            # flag without recompiling.
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(
                input_batch, [image_batch, g_labels])
            discriminator.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss[1]))

            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)
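# Comment list (blog page navigation residue; not part of the script)
# Article table of contents (blog page navigation residue; not part of the script)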