test_searchgrid.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:searchgrid 作者: jnothman 项目源码 文件源码
def test_make_grid_search():
    X, y = load_iris(return_X_y=True)
    lr = LogisticRegression()
    svc = set_grid(SVC(), kernel=['poly'], degree=[2, 3])
    gs1 = make_grid_search(lr, cv=5)  # empty grid
    gs2 = make_grid_search(svc, cv=5)
    gs3 = make_grid_search([lr, svc], cv=5)
    for gs, n_results in [(gs1, 1), (gs2, 2), (gs3, 3)]:
        gs.fit(X, y)
        assert gs.cv == 5
        assert len(gs.cv_results_['params']) == n_results

    svc_mask = gs3.cv_results_['param_root'] == svc
    assert svc_mask.sum() == 2
    assert gs3.cv_results_['param_root__degree'][svc_mask].tolist() == [2, 3]
    assert gs3.cv_results_['param_root'][~svc_mask].tolist() == [lr]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号