test_search.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_randomized_search_grid_scores():
    """Check the structure and consistency of RandomizedSearchCV.grid_scores_.

    Fits a ``RandomizedSearchCV`` over an ``SVC`` with continuous ``C`` and
    ``gamma`` distributions and ``iid=False``, then verifies that:

    * exactly one grid-score entry is recorded per sampled candidate;
    * each entry holds one validation score per CV fold, its
      ``mean_validation_score`` equals the plain mean of those fold scores,
      and its ``parameters`` cover exactly the searched parameter names;
    * ``best_score_`` equals the highest ``mean_validation_score`` and
      ``best_params_`` belongs to one of the candidates reaching it.
    """
    # Make a dataset with a lot of noise to get various kinds of prediction
    # errors across CV folds and parameter settings.
    X, y = make_classification(n_samples=200, n_features=100, n_informative=3,
                               random_state=0)

    # XXX: as of today (scipy 0.12) it's not possible to set the random seed
    # of scipy.stats distributions: the assertions in this test should thus
    # not depend on the randomization.
    params = dict(C=expon(scale=10),
                  gamma=expon(scale=0.1))
    n_cv_iter = 3
    n_search_iter = 30
    search = RandomizedSearchCV(SVC(), n_iter=n_search_iter, cv=n_cv_iter,
                                param_distributions=params, iid=False)
    search.fit(X, y)
    grid_scores = search.grid_scores_
    assert_equal(len(grid_scores), n_search_iter)

    # Check consistency of the structure of each cv_score item.
    expected_param_names = sorted(params)
    for cv_score in grid_scores:
        assert_equal(len(cv_score.cv_validation_scores), n_cv_iter)
        # Because iid is False, mean_validation_score is the unweighted mean
        # of the per-fold scores rather than the aggregate sample-wise mean.
        assert_almost_equal(np.mean(cv_score.cv_validation_scores),
                            cv_score.mean_validation_score)
        # sorted() on a dict yields its sorted keys and already returns a list.
        assert_equal(sorted(cv_score.parameters), expected_param_names)

    # Check the consistency with the best_score_ and best_params_ attributes.
    # Only the top score is needed, so take the max instead of sorting.
    best_score = max(s.mean_validation_score for s in grid_scores)
    assert_equal(search.best_score_, best_score)

    # Several candidates may tie for the best score; best_params_ only has to
    # be one of them.
    tied_best_params = [s.parameters for s in grid_scores
                        if s.mean_validation_score == best_score]
    assert_true(search.best_params_ in tied_best_params,
                "best_params_={0} is not part of the"
                " tied best models: {1}".format(
                    search.best_params_, tied_best_params))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号