improved_model.py source code

python

Project: Convolution-neural-networks-made-easy-with-keras  Author: mingruimingrui
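import os
import sys

# NOTE: `epoch`, `batch_size` and `image_generator` are not defined in this
# excerpt; they are assumed to be module-level names in improved_model.py.


def train(model, X_train, y_train, X_test, y_test):
    """Train `model` with data augmentation, saving weights after every epoch."""
    sys.stdout.write('Training model with data augmentation\n\n')
    sys.stdout.flush()

    # fit() computes any dataset statistics the generator's settings require
    datagen = image_generator()
    datagen.fit(X_train)

    # make sure the checkpoint directory exists before the first save
    os.makedirs('./models', exist_ok=True)

    # train each iteration individually to back up current state
    # safety measure against potential crashes
    epoch_count = 0
    while epoch_count < epoch:
        epoch_count += 1
        sys.stdout.write('Epoch count: ' + str(epoch_count) + '\n')
        sys.stdout.flush()
        # fit_generator is deprecated in newer Keras releases, where
        # model.fit accepts generators directly
        model.fit_generator(datagen.flow(X_train, y_train, batch_size=batch_size),
                            steps_per_epoch=len(X_train) // batch_size,
                            epochs=1,
                            validation_data=(X_test, y_test))
        sys.stdout.write('Epoch {} done, saving model to file\n\n'.format(epoch_count))
        sys.stdout.flush()
        model.save_weights('./models/convnet_improved_weights.h5')

    return model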
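
The loop above saves weights after every epoch, but it always restarts counting from zero, so a rerun after a crash repeats every epoch. A minimal sketch of one way to make the loop resumable is shown below; it is not part of the original project. The names `load_completed_epochs`, `run_with_checkpoints`, `train_one_epoch`, `save_checkpoint` and `progress_path` are all hypothetical, and the two callbacks are assumed to wrap calls such as `model.fit_generator(..., epochs=1)` and `model.save_weights(...)`.

```python
import json
import os


def load_completed_epochs(progress_path):
    """Return the number of epochs already finished, or 0 if no record exists."""
    if not os.path.exists(progress_path):
        return 0
    with open(progress_path) as f:
        return json.load(f)['completed_epochs']


def run_with_checkpoints(train_one_epoch, save_checkpoint, total_epochs, progress_path):
    """Run epochs one at a time, checkpointing and recording progress after each.

    train_one_epoch(n) and save_checkpoint() are hypothetical callbacks supplied
    by the caller; this function only handles the bookkeeping.
    """
    completed = load_completed_epochs(progress_path)
    while completed < total_epochs:
        train_one_epoch(completed + 1)
        save_checkpoint()
        completed += 1
        # Write to a temporary file first, then rename it into place, so a
        # crash in the middle of the write cannot leave a truncated record.
        tmp_path = progress_path + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump({'completed_epochs': completed}, f)
        os.replace(tmp_path, progress_path)
    return completed
```

When resuming, the caller would also need to reload the saved weights (for example with `model.load_weights`) before calling `run_with_checkpoints` again. Optimizer state is not stored by `save_weights`, so a resumed run is not guaranteed to match an uninterrupted one exactly.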