def test_check_cv_return_types():
    """Check the splitter class returned by ``cval.check_cv`` per target kind.

    Each case asks for 3 folds and asserts the type of the returned object:

    * no target, ``classifier=False``            -> ``cval.KFold``
    * 1-D binary / three-class integer targets,
      ``classifier=True``                        -> ``cval.StratifiedKFold``
    * 2-D targets (nested list and 2-D array),
      ``classifier=True``                        -> ``cval.KFold``
    """
    n_folds = 3

    # Case 1: only the feature matrix is supplied and classifier=False.
    features = np.ones((9, 2))
    splitter = cval.check_cv(n_folds, features, classifier=False)
    assert_true(isinstance(splitter, cval.KFold))

    # Cases 2-3: 1-D label vectors (9 samples) with classifier=True.
    one_dim_targets = (
        np.array([0, 1, 0, 1, 0, 0, 1, 1, 1]),  # binary labels
        np.array([0, 1, 0, 1, 2, 1, 2, 0, 2]),  # three distinct labels
    )
    for target in one_dim_targets:
        splitter = cval.check_cv(n_folds, features, target, classifier=True)
        assert_true(isinstance(splitter, cval.StratifiedKFold))

    # Cases 4-5: 2-D targets (5 samples) with classifier=True.
    features = np.ones((5, 2))
    two_dim_targets = (
        # "multilabel" case: deliberately a plain nested list, not an ndarray.
        [[1, 0, 1], [1, 1, 0], [0, 0, 0], [0, 1, 1], [1, 0, 0]],
        # "multioutput" case: two integer outputs per sample.
        np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]]),
    )
    for target in two_dim_targets:
        splitter = cval.check_cv(n_folds, features, target, classifier=True)
        assert_true(isinstance(splitter, cval.KFold))
# Scraped page residue, not part of the test: "评论列表" (comment list), "文章目录" (article table of contents).