def test_cache_cv():
    """Check that ``cache_cv`` controls how often CV train/test pieces are extracted.

    ``X`` is viewed as ``CountTakes`` so that ``X2.count`` records how many
    times the data is sliced into train/test pieces.
    NOTE(review): assumes ``CountTakes`` bumps ``count`` once per
    extraction; confirm against its definition.

    * With ``cache_cv=False`` the expected count is
      (1 train + 1 test) * n_params * n_splits. Each parameter candidate
      re-extracts every split.
    * With ``cache_cv=True`` the expected count is only
      (1 train + 1 test) * n_splits. Extraction is not repeated per candidate.
    """
    X, y = make_classification(n_samples=100, n_features=10, random_state=0)

    # Keep the grid and the number of folds in one place so the expected
    # counts below stay in sync with the estimator configuration.
    param_grid = {'foo_param': [0, 1, 2]}
    n_splits = 3
    n_params = len(param_grid['foo_param'])
    pieces_per_split = 2  # one train piece + one test piece

    X2 = X.view(CountTakes)
    gs = dcv.GridSearchCV(MockClassifier(), param_grid,
                          cv=n_splits, cache_cv=False, scheduler='sync')
    gs.fit(X2, y)
    assert X2.count == pieces_per_split * n_params * n_splits

    # Use a fresh view so the counter starts from zero for the cached run.
    X2 = X.view(CountTakes)
    assert X2.count == 0
    gs.cache_cv = True
    gs.fit(X2, y)
    assert X2.count == pieces_per_split * n_splits
# Scraped page residue (not code): 评论列表 ("comment list"), 文章目录 ("table of contents")