def test_grid_search_labels():
    # Label-aware CV splitters must surface a ValueError through
    # GridSearchCV.fit when no labels are supplied, and must run
    # cleanly once labels are forwarded to the splitter.
    random_gen = np.random.RandomState(0)

    # Small two-class problem plus one label (0, 1 or 2) per sample.
    X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
    labels = random_gen.randint(0, 3, 15)

    estimator = LinearSVC(random_state=0)
    param_grid = {'C': [1]}

    # Splitters that cannot produce folds without a labels array.
    label_splitters = [
        LeaveOneLabelOut(),
        LeavePLabelOut(2),
        LabelKFold(),
        LabelShuffleSplit(),
    ]
    for splitter in label_splitters:
        search = GridSearchCV(estimator, param_grid, cv=splitter)
        # Fitting without labels has to fail with the expected message ...
        assert_raise_message(
            ValueError,
            "The labels parameter should not be None",
            search.fit, X, y)
        # ... while fitting with labels has to go through.
        search.fit(X, y, labels)

    # Splitters that ignore labels must keep working when none are passed.
    plain_splitters = [StratifiedKFold(), StratifiedShuffleSplit()]
    for splitter in plain_splitters:
        search = GridSearchCV(estimator, param_grid, cv=splitter)
        # No exception is expected here.
        search.fit(X, y)
评论列表
文章目录