def train(model, X_train, y_train, X_test, y_test,
          weights_path='./models/convnet_weights.h5'):
    """Train ``model`` one epoch at a time, saving weights after each epoch.

    Training is split into single-epoch ``fit`` calls so that the weights
    file is rewritten after every completed epoch -- a safety measure against
    crashes part-way through a long run.

    The number of epochs and the batch size are read from the module-level
    names ``epoch`` and ``batch_size``, which must be defined before calling.

    Args:
        model: object exposing ``fit(...)`` and ``save_weights(path)``;
            presumably a Keras model (it is called with the Keras 1
            ``nb_epoch`` keyword) -- confirm.
        X_train, y_train: training inputs and targets passed to ``fit``.
        X_test, y_test: held-out data passed to ``fit`` as
            ``validation_data``.
        weights_path: file the weights are written to after every epoch;
            it is overwritten each time. Defaults to the path this function
            always used before the parameter existed.

    Returns:
        The same ``model`` object that was passed in, after training.
    """
    sys.stdout.write('Training model\n\n')
    sys.stdout.flush()
    # NOTE(review): ``nb_epoch`` is the Keras 1 spelling (Keras 2 renamed it
    # to ``epochs``). Each separate ``fit`` call presumably restarts Keras's
    # own epoch numbering, so epoch-based callbacks/schedules would not see
    # the global epoch index -- confirm this is acceptable.
    for epoch_count in range(1, epoch + 1):
        sys.stdout.write('Epoch count: {}\n'.format(epoch_count))
        sys.stdout.flush()
        model.fit(X_train, y_train, batch_size=batch_size,
                  nb_epoch=1, validation_data=(X_test, y_test))
        sys.stdout.write(
            'Epoch {} done, saving model to file\n\n'.format(epoch_count))
        sys.stdout.flush()
        # Overwrite the same file so it always holds the latest epoch's state.
        model.save_weights(weights_path)
    return model
basic_model.py 文件源码
python
阅读 33
收藏 0
点赞 0
评论 0
评论列表
文章目录