test_model_selection.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:dask-searchcv 作者: dask 项目源码 文件源码
def test_grid_search_dask_inputs():
    """Fit ``dcv.GridSearchCV`` on every mix of numpy, dask and delayed inputs.

    ``X``, ``y`` and ``groups`` are each supplied as a numpy array, a dask
    array (chunks of 5) and a ``delayed`` object, and all combinations of
    those containers are exercised. For each combination the test checks:

    * fitting without ``groups`` (the splitter is ``GroupKFold``) raises a
      ``ValueError`` whose message contains "parameter should not be None";
    * fitting with ``groups`` gives a ``best_estimator_`` whose support
      vectors match those of a plain ``SVC(C=1, random_state=0)`` fitted
      directly on the numpy data.
    """
    def as_variants(array):
        # Same data in three containers: numpy, chunked dask array, delayed.
        return [array, da.from_array(array, chunks=5), delayed(array)]

    features, labels = make_classification(
        n_samples=15, n_classes=2, random_state=0
    )
    group_ids = np.random.RandomState(0).randint(0, 3, 15)

    splitter = GroupKFold()
    estimator = SVC(random_state=0)
    # Single-point grid, so the search can only select C=1.
    param_grid = {'C': [1]}

    # Reference result computed without any search machinery.
    reference = SVC(C=1, random_state=0)
    reference.fit(features, labels)
    expected_vectors = reference.support_vectors_

    combinations = product(
        as_variants(features),
        as_variants(labels),
        as_variants(group_ids),
    )
    for X, y, groups in combinations:
        search = dcv.GridSearchCV(estimator, param_grid, cv=splitter)

        # Omitting ``groups`` is expected to be rejected.
        with pytest.raises(ValueError) as err:
            search.fit(X, y)
        assert "parameter should not be None" in str(err.value)

        search.fit(X, y, groups=groups)
        np.testing.assert_allclose(
            expected_vectors, search.best_estimator_.support_vectors_
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号