def _auc_scorer():
    """Build a weighted ROC-AUC scorer that is fed continuous model outputs.

    ``roc_auc_score`` ranks samples by a score, so it must receive
    ``decision_function`` / ``predict_proba`` output.  A plain
    ``make_scorer(roc_auc_score)`` would feed it hard ``predict()`` labels,
    collapsing the ROC curve to a single operating point.  The keyword that
    requests threshold-style output differs between scikit-learn versions
    (``response_method`` from 1.4 on, ``needs_threshold`` before), so pick
    whichever the installed ``make_scorer`` accepts.
    """
    import inspect

    make_scorer = sklearn.metrics.make_scorer
    if 'response_method' in inspect.signature(make_scorer).parameters:
        output_kwargs = {'response_method': ('decision_function', 'predict_proba')}
    else:
        output_kwargs = {'needs_threshold': True}
    return make_scorer(sklearn.metrics.roc_auc_score, average='weighted', **output_kwargs)


def _grid_search_accepts_iid():
    """Return True if the installed ``GridSearchCV`` still has an ``iid`` parameter.

    ``iid`` was deprecated in scikit-learn 0.22 and removed in 0.24 (passing it
    there raises ``TypeError``).  Newer versions always average fold scores
    without weighting, which is exactly what ``iid=False`` asked for.
    """
    import inspect

    return 'iid' in inspect.signature(GridSearchCV.__init__).parameters


def select_model(training_data, method='logistic',
                 do_segment_split=True,
                 processes=1,
                 cv_verbosity=2,
                 model_params=None,
                 random_state=None):
    """
    Fit a model given by *method* to the training data with a grid search.

    :param training_data: The training data to fit the model with. It must
        contain a ``'Preictal'`` column, which is used as the target; every
        other column is used as a feature (DataFrame-style ``drop``/indexing
        is used on it).
    :param method: A string which specifies the model to use; forwarded to
        ``get_model``.
    :param do_segment_split: If True, the training data will be split by
        segment; forwarded to ``get_cv_generator``.
    :param processes: The number of parallel jobs for the grid search
        (``n_jobs`` of ``GridSearchCV``).
    :param cv_verbosity: The verbosity level of the grid search (``verbose``
        of ``GridSearchCV``); 0 is silent, larger values print more.
    :param model_params: Optional keyword arguments forwarded to
        ``get_model`` to tune the grid search.
    :param random_state: A constant which will seed the random number
        generators of the CV splitter and the model if given.
    :return: The fitted grid search object (refit on all training data).
    """
    logging.info("Training a %s model", method)
    training_data_x = training_data.drop('Preictal', axis=1)
    training_data_y = training_data['Preictal']
    cv = get_cv_generator(training_data,
                          do_segment_split=do_segment_split,
                          random_state=random_state)
    scorer = _auc_scorer()
    # NOTE(review): presumably supplies at least ``estimator`` and
    # ``param_grid`` for GridSearchCV -- confirm against get_model.
    model_dict = get_model(method,
                           training_data_x,
                           training_data_y,
                           model_params=model_params,
                           random_state=random_state)
    cv_kwargs = dict(cv=cv,
                     scoring=scorer,
                     n_jobs=processes,
                     pre_dispatch='2*n_jobs',
                     refit=True,
                     verbose=cv_verbosity)
    if _grid_search_accepts_iid():
        # Only older scikit-learn accepts ``iid``; newer versions already
        # behave as iid=False.
        cv_kwargs['iid'] = False
    # Entries from get_model() take precedence over the defaults above.
    cv_kwargs.update(model_dict)
    logging.info("Running grid search using the parameters: %s", model_dict)
    clf = GridSearchCV(**cv_kwargs)
    clf.fit(training_data_x, training_data_y)
    return clf
seizure_modeling.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录