def create_new_stacked_ensemble():
    """List stacked ensembles (GET) or create and queue a new one (POST).

    The database to use is resolved from the request's query string via
    ``functions.get_path_from_query_string``.

    GET:
        Returns a JSON list of every ``models.StackedEnsemble`` in the
        database, each rendered through its ``serialize`` attribute.

    POST:
        Expects a JSON body with the keys ``base_learner_ids``,
        ``base_learner_origin_id`` and
        ``secondary_learner_hyperparameters_source``. Validates the inputs,
        creates a ``StackedEnsemble`` with ``job_status='queued'``, commits
        it, enqueues ``rqtasks.evaluate_stacked_ensemble`` for it, and
        returns the new ensemble serialized as JSON.

    Raises:
        exceptions.UserError: if any requested base learner does not exist
            or has not finished, if the base learner origin does not exist
            (raised with status 404), or if an identical stacked ensemble
            already exists.

    Note:
        Any other HTTP method falls through and returns ``None``; presumably
        the route only allows GET and POST (route decorator not visible here).
    """
    path = functions.get_path_from_query_string(request)
    with functions.DBContextManager(path) as session:
        if request.method == 'GET':
            ensembles = session.query(models.StackedEnsemble).all()
            return jsonify([ensemble.serialize for ensemble in ensembles])

        if request.method == 'POST':
            # Parse the body only where it is used: a GET usually carries no
            # JSON body, and newer Flask versions raise 415 from get_json()
            # when the request content type is not JSON.
            req_body = request.get_json()

            # De-duplicate (preserving order) so that a repeated id is not
            # misreported as "not found": the IN query returns each row once.
            requested_ids = list(dict.fromkeys(req_body['base_learner_ids']))
            base_learners = session.query(models.BaseLearner).\
                filter(models.BaseLearner.id.in_(requested_ids)).all()
            if len(base_learners) != len(requested_ids):
                raise exceptions.UserError('Not all base learners found')
            if any(learner.job_status != 'finished' for learner in base_learners):
                raise exceptions.UserError('Not all base learners have finished')

            origin_id = req_body['base_learner_origin_id']
            base_learner_origin = session.query(models.BaseLearnerOrigin).\
                filter_by(id=origin_id).first()
            if base_learner_origin is None:
                raise exceptions.UserError('Base learner origin {} not '
                                           'found'.format(origin_id), 404)

            # Resolve the full hyperparameter set: start from the origin's
            # estimator and overlay the user-supplied ``params`` object.
            # NOTE(review): the hyperparameters arrive as source code in the
            # request body; if import_object_from_string_code executes that
            # code (as its name suggests), this endpoint runs client-supplied
            # code -- confirm it is only exposed to trusted users.
            est = base_learner_origin.return_estimator()
            params = functions.import_object_from_string_code(
                req_body['secondary_learner_hyperparameters_source'], 'params')
            est.set_params(**params)
            hyperparameters = functions.make_serializable(est.get_params())

            # Refuse to create a second ensemble with the same origin,
            # hyperparameters and (sorted) list of base learner ids.
            existing = session.query(models.StackedEnsemble).\
                filter_by(base_learner_origin_id=origin_id,
                          secondary_learner_hyperparameters=hyperparameters,
                          base_learner_ids=sorted(bl.id for bl in base_learners)).\
                first()
            if existing is not None:
                raise exceptions.UserError('Stacked ensemble exists')

            stacked_ensemble = models.StackedEnsemble(
                secondary_learner_hyperparameters=hyperparameters,
                base_learners=base_learners,
                base_learner_origin=base_learner_origin,
                job_status='queued'
            )
            session.add(stacked_ensemble)
            # Commit before enqueueing so the row has an id and is persisted;
            # presumably the worker loads it by that id -- confirm in rqtasks.
            session.commit()
            with Connection(get_redis_connection()):
                rqtasks.evaluate_stacked_ensemble.delay(path, stacked_ensemble.id)
            return jsonify(stacked_ensemble.serialize)
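# Web-page residue from the source this code was copied from: "Comment list"
# Web-page residue from the source this code was copied from: "Table of contents"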