def test_check_cv():
    """Check which splitter ``check_cv`` builds for various targets.

    Cases exercised below:

    * an integer ``cv`` with ``classifier=False`` splits like ``KFold``;
    * an integer ``cv`` with ``classifier=True`` and a binary or multiclass
      ``y`` splits like ``StratifiedKFold``;
    * an integer ``cv`` with ``classifier=True`` and a 2-D multilabel or
      multioutput ``y`` splits like plain ``KFold``;
    * a legacy ``sklearn.cross_validation`` iterator is wrapped so that it
      exposes a ``split`` method;
    * an unsupported ``cv`` value (a string) raises ``ValueError``.
    """
    # Non-classifier: plain KFold is expected.
    X = np.ones(9)
    cv = check_cv(3, classifier=False)
    # Use numpy.testing.assert_equal which recursively compares
    # lists of lists
    np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))

    # Classifier with a binary target: stratified folds are expected.
    y_binary = np.array([0, 1, 0, 1, 0, 0, 1, 1, 1])
    cv = check_cv(3, y_binary, classifier=True)
    np.testing.assert_equal(list(StratifiedKFold(3).split(X, y_binary)),
                            list(cv.split(X, y_binary)))

    # Classifier with a multiclass target: stratified folds are expected.
    y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])
    cv = check_cv(3, y_multiclass, classifier=True)
    np.testing.assert_equal(list(StratifiedKFold(3).split(X, y_multiclass)),
                            list(cv.split(X, y_multiclass)))

    # Classifier with a 2-D multilabel target: the splits must match
    # plain KFold.
    X = np.ones(5)
    y_multilabel = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1],
                             [1, 1, 0, 1], [0, 0, 1, 0]])
    cv = check_cv(3, y_multilabel, classifier=True)
    np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))

    # Classifier with a 2-D multioutput target: likewise plain KFold.
    y_multioutput = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]])
    cv = check_cv(3, y_multioutput, classifier=True)
    np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))

    # Check if the old style classes are wrapped to have a split method
    X = np.ones(9)
    y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])
    cv1 = check_cv(3, y_multiclass, classifier=True)

    # Record (rather than emit) any warning raised by importing the
    # legacy module.
    with warnings.catch_warnings(record=True):
        from sklearn.cross_validation import StratifiedKFold as OldSKF

    # The old-style iterator receives ``y`` at construction time, so the
    # wrapped object's ``split`` is called without arguments.
    cv2 = check_cv(OldSKF(y_multiclass, n_folds=3))

    np.testing.assert_equal(list(cv1.split(X, y_multiclass)),
                            list(cv2.split()))

    # A string is not an accepted ``cv`` specification.
    assert_raises(ValueError, check_cv, cv="lolo")
# Comment list
# Table of contents