def test_make_grid_search():
    """Exercise make_grid_search on one bare, one gridded and a list of estimators.

    Three searches are built, all with ``cv=5``:

    * a LogisticRegression with no grid attached (expected: 1 candidate),
    * an SVC whose grid was attached via ``set_grid`` (expected: 2 candidates,
      one per ``degree`` value),
    * a list holding both estimators (expected: 3 candidates in total).

    Each search is fitted on the iris data. The test checks that ``cv`` is kept
    on the search object and that ``cv_results_['params']`` has the expected
    number of entries. For the list-based search it also checks the
    ``param_root`` / ``param_root__degree`` columns of ``cv_results_``.
    """
    features, target = load_iris(return_X_y=True)

    logreg = LogisticRegression()
    # Only the SVC carries a grid: a single kernel with two polynomial degrees.
    svm = set_grid(SVC(), kernel=['poly'], degree=[2, 3])

    search_cases = [
        (make_grid_search(logreg, cv=5), 1),  # empty grid -> single candidate
        (make_grid_search(svm, cv=5), 2),
        (make_grid_search([logreg, svm], cv=5), 3),
    ]

    for search, expected_candidates in search_cases:
        search.fit(features, target)
        assert search.cv == 5
        assert len(search.cv_results_['params']) == expected_candidates

    # The last case searched over the list of estimators.
    combined_results = search_cases[-1][0].cv_results_
    is_svm = combined_results['param_root'] == svm
    assert is_svm.sum() == 2
    # The SVC rows should list its two degrees, in grid order.
    assert combined_results['param_root__degree'][is_svm].tolist() == [2, 3]
    # Every remaining row should be the plain LogisticRegression.
    assert combined_results['param_root'][~is_svm].tolist() == [logreg]
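# NOTE(review): leftover navigation labels from the web page this test was
# copied from ("Comment list", "Table of contents"); not part of the test.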