test_model_selection_sklearn.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:dask-searchcv 作者: dask 项目源码 文件源码
def test_hyperparameter_searcher_with_fit_params(cls, kwargs):
    """Check that a hyperparameter searcher forwards fit parameters.

    A ``CheckingClassifier`` wrapped in a ``Pipeline`` step named ``'clf'``
    expects the fit parameters ``spam`` and ``eggs``. The searcher built
    from ``cls`` must route ``clf__``-prefixed keyword arguments of ``fit``
    to that step.

    The test checks three things:

    * a missing parameter surfaces as an ``AssertionError``;
    * complete in-memory (NumPy) parameters fit cleanly;
    * dask-backed parameters (a dask array and a delayed object) are
      accepted as well.

    ``cls`` is the searcher class under test and ``kwargs`` are extra
    constructor keyword arguments. NOTE(review): these are presumably
    supplied by a ``pytest.mark.parametrize`` decorator outside this
    view; confirm.
    """
    features = np.arange(100).reshape(10, 10)
    labels = np.array([0] * 5 + [1] * 5)

    checker = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
    pipeline = Pipeline([('clf', checker)])
    param_grid = {'clf__foo_param': [1, 2, 3]}
    search = cls(pipeline, param_grid, cv=2, **kwargs)

    # CheckingClassifier raises AssertionError when an expected fit
    # parameter is absent or its length differs from len(features).
    # Here 'eggs' is deliberately omitted.
    with pytest.raises(AssertionError) as excinfo:
        search.fit(features, labels, clf__spam=np.ones(10))
    message = str(excinfo.value)
    assert "Expected fit parameter(s) ['eggs'] not seen." in message

    # Supplying both parameters must succeed.
    numpy_params = dict(clf__spam=np.ones(10), clf__eggs=np.zeros(10))
    search.fit(features, labels, **numpy_params)

    # Fit parameters may also be dask collections or delayed objects.
    dask_params = {
        'clf__spam': da.ones(10, chunks=2),
        'clf__eggs': dask.delayed(np.zeros(10)),
    }
    search.fit(features, labels, **dask_params)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号