test_search.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_grid_search_with_multioutput_data():
    """Check search CV fold scores on multi-output data.

    For both an exhaustive grid search and a randomized search, every
    per-fold score reported in ``grid_scores_`` must agree with the score
    obtained by configuring the estimator with the same parameters,
    fitting it on that fold's training rows and scoring it on the fold's
    test rows.
    """
    X, y = make_multilabel_classification(return_indicator=True,
                                          random_state=0)

    param_grid = {"max_depth": [1, 2, 3, 4]}
    folds = KFold(random_state=0)

    models = [DecisionTreeRegressor(random_state=0),
              DecisionTreeClassifier(random_state=0)]

    def build_grid_search(model):
        # Exhaustive search over every max_depth candidate.
        return GridSearchCV(model, param_grid, cv=folds)

    def build_random_search(model):
        # Randomized search drawing 3 candidates from the same grid.
        return RandomizedSearchCV(model, param_grid,
                                  cv=folds, n_iter=3)

    # All grid searches run first, then all randomized searches, so the
    # sequence of fits is the same as with two separate loops.
    for build_search in (build_grid_search, build_random_search):
        for model in models:
            search = build_search(model)
            search.fit(X, y)

            for params, _, fold_scores in search.grid_scores_:
                # Recompute each fold's score by hand with these params.
                model.set_params(**params)

                for k, (fit_idx, eval_idx) in enumerate(folds.split(X, y)):
                    model.fit(X[fit_idx], y[fit_idx])
                    expected = model.score(X[eval_idx], y[eval_idx])
                    assert_almost_equal(expected, fold_scores[k])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号