def test_hyperparameter_searcher_with_fit_params(cls, kwargs):
    """Check that a searcher forwards ``fit`` params to a pipeline step.

    ``cls`` is the searcher class under test and ``kwargs`` holds extra
    constructor arguments for it (presumably supplied by a
    ``pytest.mark.parametrize`` decorator outside this view).

    The test verifies that:

    * fitting without one of the expected fit parameters raises the
      ``AssertionError`` produced by ``CheckingClassifier``;
    * fitting with all expected parameters succeeds when they are NumPy
      arrays;
    * fitting also succeeds when the parameters are dask objects (a dask
      array and a ``dask.delayed`` value).
    """
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    # Fit parameters must be sample-aligned with X, so derive their length
    # from X instead of repeating a magic number.
    n_samples = len(X)

    clf = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
    pipe = Pipeline([('clf', clf)])
    searcher = cls(pipe, {'clf__foo_param': [1, 2, 3]}, cv=2, **kwargs)

    # The CheckingClassifier generates an assertion error if
    # a parameter is missing or has length != len(X).
    with pytest.raises(AssertionError) as exc:
        searcher.fit(X, y, clf__spam=np.ones(n_samples))
    # Checked outside the ``with`` block: anything placed after the raising
    # call inside it would never execute.
    assert "Expected fit parameter(s) ['eggs'] not seen." in str(exc.value)

    searcher.fit(X, y, clf__spam=np.ones(n_samples),
                 clf__eggs=np.zeros(n_samples))

    # Test with dask objects as parameters
    searcher.fit(X, y, clf__spam=da.ones(n_samples, chunks=2),
                 clf__eggs=dask.delayed(np.zeros(n_samples)))
test_model_selection_sklearn.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录