def train(name, model, callbacks=None, batch_size=32, nb_epoch=200):
    """Common cifar10 training code.

    Trains ``model`` with ``fit_generator`` on augmented batches drawn from
    the module-level ``X_train`` / ``Y_train`` arrays, validating against the
    module-level ``X_test`` / ``Y_test``.

    Three callbacks are always attached in addition to any supplied by the
    caller: ``ReduceLROnPlateau`` and ``ModelCheckpoint`` (both monitoring
    ``val_loss``) and ``TensorBoard``.

    Args:
        name: Run identifier. Used to build the TensorBoard log directory
            (``./logs/<name>``) and the checkpoint path
            (``./weights/<name>.hdf5``), and printed when training starts.
        model: Object exposing ``fit_generator`` (presumably a compiled Keras
            model -- confirm against callers).
        callbacks: Optional iterable of extra callbacks. It is copied, so the
            caller's object is never modified.
        batch_size: Batch size passed to ``datagen.flow``.
        nb_epoch: Number of epochs passed to ``fit_generator``.

    Returns:
        None.
    """
    # Work on a copy: extending the caller's list in place would leak the
    # built-in callbacks into it and duplicate them on repeated calls.
    callbacks = list(callbacks or [])

    tb = TensorBoard(log_dir='./logs/{}'.format(name))
    # Only the best weights (lowest val_loss) are kept on disk.
    model_checkpoint = ModelCheckpoint(
        './weights/{}.hdf5'.format(name),
        monitor='val_loss',
        save_best_only=True)
    # Reduce the learning rate by 10x after 5 epochs without val_loss
    # improvement, never going below 1e-7.
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7)
    callbacks.extend([reduce_lr, tb, model_checkpoint])

    print("Training {}".format(name))

    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False)  # randomly flip images vertically

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    # NOTE(review): every feature-wise / ZCA option above is disabled, so this
    # call presumably has no effect on the generated batches; it is kept so
    # that enabling one of those options keeps working -- confirm.
    datagen.fit(X_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(
        datagen.flow(X_train, Y_train, batch_size=batch_size),
        samples_per_epoch=X_train.shape[0],
        nb_epoch=nb_epoch,
        verbose=2,
        max_q_size=1000,
        callbacks=callbacks,
        validation_data=(X_test, Y_test))
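# Comment list  (web-page navigation residue; not part of the code)
# Table of contents  (web-page navigation residue; not part of the code)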