def test_build_param_grid_set_estimator():
    """Check build_param_grid when a pipeline step is searched over estimators.

    The pipeline's 'clf' step starts as None. Via set_grid, it is assigned a
    list of four candidate estimators. Two of these candidates (the SVCs) have
    their own grids attached with set_grid; the other two do not.

    The expected result is a list of three sub-grids, each combined with the
    selector's 'sel__k' values:
      - one sub-grid for each SVC, holding that SVC's own parameters;
      - one final sub-grid listing the parameter-free candidates together.

    The comparison is a list ``==``, so the sub-grid order is part of what
    is asserted.

    NOTE(review): the expected grid refers to the SVC instances themselves,
    which presumes set_grid returns (or annotates) the same object it is
    given. Confirm against set_grid's implementation.
    """
    # Candidate estimators, created in the same order as the grid lists them.
    linear_svc = SVC()
    logreg = LogisticRegression()
    poly_svc = SVC()
    sgd = SGDClassifier()

    # Feature selector carrying its own small grid over k.
    selector = set_grid(SelectKBest(), k=[2, 3])
    pipe = Pipeline([('sel', selector), ('clf', None)])

    # The 'clf' step is searched over these candidates; only the two SVCs
    # get additional hyper-parameter grids.
    candidates = [
        set_grid(linear_svc, kernel=['linear']),
        logreg,
        set_grid(poly_svc, kernel=['poly'], degree=[2, 3]),
        sgd,
    ]
    estimator = set_grid(pipe, clf=candidates)

    expected = [
        {'clf': [linear_svc], 'clf__kernel': ['linear'], 'sel__k': [2, 3]},
        {'clf': [poly_svc], 'clf__kernel': ['poly'],
         'clf__degree': [2, 3], 'sel__k': [2, 3]},
        # Candidates without a grid of their own share a single sub-grid.
        {'clf': [logreg, sgd], 'sel__k': [2, 3]},
    ]
    assert build_param_grid(estimator) == expected
评论列表
文章目录