_validation.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目：mriqc　作者：poldracklab　项目源码　文件源码
def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
                    n_jobs=1, verbose=0, fit_params=None,
                    pre_dispatch='2*n_jobs'):
    """
    Evaluate a score by cross-validation, optionally with several scorers.

    Parameters
    ----------
    estimator : estimator object
        Cloned and fitted independently on every training split.
    X, y, groups : array-like
        Data, targets and group labels. They are made indexable and passed
        to ``cv.split``.
    scoring : scoring spec or list/tuple of scoring specs, optional
        A single value is wrapped into a one-element list. Each entry is
        resolved with ``check_scoring``.
    cv : optional
        Anything accepted by ``check_cv``.
    n_jobs, verbose, pre_dispatch
        Forwarded to ``Parallel``. ``verbose`` is also forwarded to
        ``_fit_and_score``.
    fit_params : dict or None
        Forwarded to ``_fit_and_score``.

    Returns
    -------
    scores : numpy.ndarray
        The per-fold results of ``_fit_and_score``, stacked with
        ``np.array`` and with singleton dimensions removed by ``np.squeeze``.
    group_order : list
        If ``cv`` has a ``groups`` attribute: for each split, in split
        order, the first group label found in that split's test indices.
        Otherwise an empty list.
    """
    if not isinstance(scoring, (list, tuple)):
        scoring = [scoring]

    X, y, groups = indexable(X, y, groups)

    cv = check_cv(cv, y, classifier=is_classifier(estimator))
    # Materialize the splits once. They are consumed by the parallel fit
    # below and iterated again to build ``group_order``.
    splits = list(cv.split(X, y, groups))
    scorer = [check_scoring(estimator, scoring=s) for s in scoring]
    # We clone the estimator to make sure that all the folds are
    # independent, and that it is pickle-able.
    parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
                        pre_dispatch=pre_dispatch)
    # NOTE(review): ``scorer`` is a *list* of scorers here. This assumes a
    # ``_fit_and_score`` variant that accepts multiple scorers — confirm
    # against the ``_fit_and_score`` actually in scope.
    scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
                                              train, test, verbose, None,
                                              fit_params)
                      for train, test in splits)

    group_order = []
    if hasattr(cv, 'groups'):
        # Convert the group labels to an array once, not once per fold.
        cv_groups = np.array(cv.groups)
        # Only the first label of each test fold is kept. This presumably
        # assumes every test fold holds a single group (leave-one-group-out
        # style) — TODO confirm against the splitters used with this function.
        group_order = [cv_groups[test].tolist()[0] for _, test in splits]
    return np.squeeze(np.array(scores)), group_order
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号