test_model_selection_sklearn.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dask-searchcv 作者: dask 项目源码 文件源码
def test_random_search_cv_results():
    """Check the structure of ``cv_results_`` for ``dcv.RandomizedSearchCV``.

    Two randomized searches over ``SVC`` are fitted on the same data, one
    with ``iid=False`` and one with ``iid=True``.  For each fitted search the
    test asserts that its ``iid`` attribute matches the value passed in, runs
    the shared ``check_cv_results_*`` helpers against the expected parameter
    and score keys, and asserts that no entry of the sampled parameter
    arrays is masked.
    """
    # Only 3 of the 100 features are informative, so the data is noisy and
    # prediction errors vary across CV folds and parameter settings.
    X, y = make_classification(
        n_samples=200,
        n_features=100,
        n_informative=3,
        random_state=0,
    )

    # The distributions are sampled without a fixed seed: scipy.stats now
    # accepts `seed`, but scipy 0.12 (still supported) does not.  Nothing
    # asserted below may therefore depend on the particular values drawn.
    n_splits = 3
    n_search_iter = 30
    params = dict(C=expon(scale=10), gamma=expon(scale=0.1))

    # Build and fit the iid=False search first, then the iid=True one,
    # remembering which flag each search was constructed with.
    fitted_searches = []
    for iid_flag in (False, True):
        candidate = dcv.RandomizedSearchCV(
            SVC(),
            n_iter=n_search_iter,
            cv=n_splits,
            iid=iid_flag,
            param_distributions=params,
            return_train_score=True,
        )
        candidate.fit(X, y)
        fitted_searches.append((candidate, iid_flag))

    param_keys = ('param_C', 'param_gamma')
    score_keys = (
        'mean_test_score', 'mean_train_score',
        'rank_test_score',
        'split0_test_score', 'split1_test_score', 'split2_test_score',
        'split0_train_score', 'split1_train_score', 'split2_train_score',
        'std_test_score', 'std_train_score',
        'mean_fit_time', 'std_fit_time',
        'mean_score_time', 'std_score_time',
    )
    # One candidate is expected per sampled parameter setting.
    n_candidates = n_search_iter

    for fitted, expected_iid in fitted_searches:
        assert expected_iid == fitted.iid
        results = fitted.cv_results_

        # Structural checks are delegated to the shared helpers.
        check_cv_results_array_types(results, param_keys, score_keys)
        check_cv_results_keys(results, param_keys, score_keys, n_candidates)

        # Every candidate samples both parameters, so neither parameter
        # array should contain a masked entry.
        for key in param_keys:
            assert not any(results[key].mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号