seizure_modeling.py source code

python

Project: kaggle-seizure-prediction  Author: sics-lm
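import logging

import sklearn.metrics
from sklearn.model_selection import GridSearchCV

# get_cv_generator() and get_model() are project helpers not shown in this excerpt.


def select_model(training_data, method='logistic',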
                 do_segment_split=True,
                 processes=1,
                 cv_verbosity=2,
                 model_params=None,
                 random_state=None):
    """
    Fits the model named by *method* to the training data with a
    cross-validated grid search over its hyperparameters.
    :param training_data: A pandas DataFrame of features with a binary 'Preictal' target column.
    :param method: A string which specifies the model to use.
    :param do_segment_split: If True, the training data is split into cross-validation folds by segment.
    :param processes: The number of processes to use for the grid search.
    :param cv_verbosity: The verbosity level of the grid search; 0 is silent and higher values print more progress information.
    :param model_params: An optional dictionary of keyword arguments used to tune the grid search.
    :param random_state: A constant which seeds the random number generator if given.
    :return: The fitted GridSearchCV object, refit on all of the training data with the best parameters.
    """

    logging.info("Training a {} model".format(method))

    training_data_x = training_data.drop('Preictal', axis=1)
    training_data_y = training_data['Preictal']

    cv = get_cv_generator(training_data, do_segment_split=do_segment_split, random_state=random_state)

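    # Score on continuous classifier outputs (decision_function or predict_proba)
    # rather than hard 0/1 labels; make_scorer(roc_auc_score) would call predict().
    scorer = sklearn.metrics.get_scorer('roc_auc')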
    model_dict = get_model(method,
                           training_data_x,
                           training_data_y,
                           model_params=model_params,
                           random_state=random_state)
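    # No iid argument: it was removed in scikit-learn 0.24, and the current
    # behaviour (an unweighted mean of the fold scores) matches the former iid=False.
    common_cv_kwargs = dict(cv=cv,
                            scoring=scorer,
                            n_jobs=processes,
                            pre_dispatch='2*n_jobs',
                            refit=True,
                            verbose=cv_verbosity)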

    cv_kwargs = dict(common_cv_kwargs)
    cv_kwargs.update(model_dict)

    logging.info("Running grid search using the parameters: {}".format(model_dict))
    clf = GridSearchCV(**cv_kwargs)
    clf.fit(training_data_x, training_data_y)

    return clf