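import os

# `epoch` (number of epochs), `batch_size` and `image_generator()` are not
# defined in this excerpt; they are assumed to be module-level definitions
# elsewhere in improved_model.py.


def train(model, X_train, y_train, X_test, y_test):
    """Train `model` on augmented data, saving its weights after every epoch."""
    print('Training model with data augmentation\n', flush=True)
    datagen = image_generator()
    # Compute dataset-wide statistics from the training images; presumably only
    # needed if the generator uses featurewise centering/normalization or ZCA whitening.
    datagen.fit(X_train)
    os.makedirs('./models', exist_ok=True)
    # Train one epoch per fit call and back up the weights after each one,
    # as a safety measure: a crash loses at most the epoch in progress.
    for epoch_count in range(1, epoch + 1):
        print('Epoch count: {}'.format(epoch_count), flush=True)
        # fit_generator is deprecated since TensorFlow 2.1, where model.fit
        # accepts the same generator and arguments.
        model.fit_generator(datagen.flow(X_train, y_train, batch_size=batch_size),
                            steps_per_epoch=len(X_train) // batch_size,
                            epochs=1,
                            validation_data=(X_test, y_test))
        print('Epoch {} done, saving model to file\n'.format(epoch_count), flush=True)
        # Overwrites the same file every epoch, keeping only the latest weights.
        model.save_weights('./models/convnet_improved_weights.h5')
    return model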
Source file: improved_model.py (Python)
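The comments in `train` describe per-epoch saving as a backup against crashes. However, the function itself neither records which epoch finished nor reloads the saved weights, so resuming after a crash is left to the caller. The following framework-free sketch goes one step further and stores the epoch number next to the state so that a rerun resumes where it stopped. It is an illustration under assumptions, not the author's method: `train_with_checkpoints` and the `train_one_epoch(epoch, state)` callback are hypothetical names, and a JSON-serializable dict stands in for the model weights.

```python
import json
import os
import tempfile


def train_with_checkpoints(train_one_epoch, total_epochs, checkpoint_path):
    """Run the hypothetical `train_one_epoch(epoch, state)` once per epoch,
    checkpointing the returned state (a JSON-serializable dict) after each one."""
    state, start = {}, 1
    # Resume after the last completed epoch if a checkpoint already exists.
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            saved = json.load(f)
        state, start = saved['state'], saved['epoch'] + 1
    for epoch in range(start, total_epochs + 1):
        state = train_one_epoch(epoch, state)
        # Write to a temporary file, then swap it in with os.replace (atomic on
        # POSIX), so a crash mid-write never leaves a truncated checkpoint.
        tmp_path = checkpoint_path + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump({'epoch': epoch, 'state': state}, f)
        os.replace(tmp_path, checkpoint_path)
    return state


if __name__ == '__main__':
    # Simulate a crash during epoch 3, then rerun: training resumes at epoch 3.
    path = os.path.join(tempfile.mkdtemp(), 'checkpoint.json')
    crashed = []

    def fake_epoch(epoch, state):
        if epoch == 3 and not crashed:
            crashed.append(epoch)
            raise RuntimeError('simulated crash')
        print('trained epoch', epoch)
        return {'epochs_done': state.get('epochs_done', 0) + 1}

    try:
        train_with_checkpoints(fake_epoch, 5, path)
    except RuntimeError as err:
        print('crashed:', err)
    print(train_with_checkpoints(fake_epoch, 5, path))  # {'epochs_done': 5}
```

In the Keras setting, the analogous step would be to save the epoch number alongside the weights file and call `model.load_weights` before the loop when resuming.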
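The iterator returned by `datagen.flow` loops over the data indefinitely, so `fit_generator` cannot detect the end of an epoch by itself. The `steps_per_epoch=len(X_train) // batch_size` argument in `train` supplies that boundary. Floor division counts only full batches. Because `train` builds a fresh `datagen.flow(...)` iterator on every loop, the last `len(X_train) % batch_size` samples of each new shuffle go unused in that epoch. For example, with a hypothetical 50,000 images and `batch_size = 32`, there are 1562 steps covering 49,984 images, and 16 images are skipped each epoch. The sketch below imitates this in plain Python. `flow`, `take_epoch` and `steps_per_epoch` are hypothetical stand-ins that do no augmentation, not the Keras API.

```python
import itertools
import random


def flow(samples, batch_size, shuffle=True, seed=None):
    """Yield batches from `samples` forever, reshuffling on every pass.

    Like an augmenting Keras iterator (minus the augmentation), it never stops
    on its own, and the final batch of each pass may be shorter than batch_size.
    """
    rng = random.Random(seed)
    indices = list(range(len(samples)))
    while True:
        if shuffle:
            rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            yield [samples[i] for i in indices[start:start + batch_size]]


def steps_per_epoch(num_samples, batch_size):
    # Same formula as in `train`: floor division counts only full batches.
    return num_samples // batch_size


def take_epoch(batches, steps):
    # An infinite generator never ends an epoch by itself; stop after `steps` batches.
    return list(itertools.islice(batches, steps))


if __name__ == '__main__':
    data = list(range(10))
    steps = steps_per_epoch(len(data), 4)
    epoch_batches = take_epoch(flow(data, batch_size=4, seed=0), steps)
    print(steps, [len(b) for b in epoch_batches])  # 2 [4, 4]
```

Here 10 samples with a batch size of 4 give 2 steps, so each fresh epoch trains on 8 samples and skips 2.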