train.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:Digit-Classifier 作者: ardamavi 项目源码 文件源码
def train_model(model, X, X_test, Y, Y_test):

    batch_size = 100
    epochs = 2

    checkpoints = []
    if not os.path.exists('Data/Checkpoints/'):
        os.makedirs('Data/Checkpoints/')
    checkpoints.append(ModelCheckpoint('Data/Checkpoints/best_weights.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto', period=1))
    checkpoints.append(TensorBoard(log_dir='Data/Checkpoints/./logs', histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None))

    # Creates live data:
    # For better yield. The duration of the training is extended.

    # If you don't want, use this:
    # model.fit(X, Y, batch_size=batch_size, epochs=epochs, validation_data=(X_test, Y_test), shuffle=True, callbacks=checkpoints)

    generated_data = ImageDataGenerator(featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, samplewise_std_normalization=False, zca_whitening=False, rotation_range=0,  width_shift_range=0.1, height_shift_range=0.1, horizontal_flip = True, vertical_flip = False)
    generated_data.fit(X)

    model.fit_generator(generated_data.flow(X, Y, batch_size=batch_size), steps_per_epoch=X.shape[0]/6, epochs=epochs, validation_data=(X_test, Y_test), callbacks=checkpoints)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号