cifar10.py source code

python

Project: deep-learning-experiments  Author: raghakot
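from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from keras.preprocessing.image import ImageDataGenerator


def train(name, model, callbacks=None, batch_size=32, nb_epoch=200):
    """Common cifar10 training code.

    Relies on module-level X_train, Y_train, X_test and Y_test arrays,
    which are defined elsewhere in the file and not shown in this excerpt.
    """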
    # Copy so that the caller's list is not mutated by extend() below.
    callbacks = list(callbacks or [])
    tb = TensorBoard(log_dir='./logs/{}'.format(name))
    model_checkpoint = ModelCheckpoint('./weights/{}.hdf5'.format(name), monitor='val_loss', save_best_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7)
    callbacks.extend([reduce_lr, tb, model_checkpoint])

    print("Training {}".format(name))

    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False)  # do not flip images vertically

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    # All of those options are False above, so this call has no effect here.
    datagen.fit(X_train)

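    # Fit the model on the batches generated by datagen.flow().
    # samples_per_epoch, nb_epoch and max_q_size are Keras 1.x argument names;
    # Keras 2 renamed them to steps_per_epoch, epochs and max_queue_size.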
    model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
                        samples_per_epoch=X_train.shape[0],
                        nb_epoch=nb_epoch, verbose=2, max_q_size=1000,
                        callbacks=callbacks, validation_data=(X_test, Y_test))
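
The `ReduceLROnPlateau` settings above (`factor=0.1`, `patience=5`, `min_lr=1e-7`) can be easier to reason about with a small simulation. The following is a minimal pure-Python sketch of a plateau rule, not Keras's actual implementation: it ignores the callback's cooldown and minimum-delta options, and the function name `simulate_reduce_lr` and the starting learning rate of `1e-3` are assumptions made for illustration.

```python
def simulate_reduce_lr(val_losses, lr=1e-3, factor=0.1, patience=5, min_lr=1e-7):
    """Return the learning rate in effect after each epoch.

    Simplified plateau rule (hypothetical helper, not part of Keras):
    if val_loss has not improved for `patience` consecutive epochs,
    multiply the learning rate by `factor`, never going below `min_lr`.
    """
    best = float("inf")  # best validation loss seen so far
    wait = 0             # epochs since the last improvement
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history


# Three improving epochs followed by ten epochs stuck at the same loss.
losses = [1.0, 0.8, 0.7] + [0.7] * 10
lrs = simulate_reduce_lr(losses)
for epoch, (loss, lr) in enumerate(zip(losses, lrs)):
    print(epoch, loss, lr)
```

With these inputs the rate is cut once after five stalled epochs and again after five more. Since the function above monitors `val_loss` computed on `(X_test, Y_test)`, both the learning-rate schedule and the saved checkpoint are tuned against the test set.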