def check_cv_coverage(cv, X, y, labels, expected_n_iter=None):
    """Assert that ``cv`` produces the expected number of splits and that
    its test folds, taken together, contain every sample index.

    Parameters
    ----------
    cv : object
        Splitter exposing ``get_n_splits(X, y, labels)`` and
        ``split(X, y, labels)``; ``split`` must yield ``(train, test)``
        pairs.
    X, y, labels : objects
        Forwarded unchanged to ``cv.get_n_splits`` and ``cv.split``.
        ``X`` is additionally passed to ``_num_samples``.
    expected_n_iter : int or None, optional
        When given, ``cv.get_n_splits`` must report exactly this value.
        When ``None``, the value reported by ``cv.get_n_splits`` is used
        as the expected number of iterations instead.

    Notes
    -----
    Failures surface through ``assert_equal`` (and whatever
    ``check_valid_split`` raises for an individual split).
    """
    # presumably the number of rows in X -- see _num_samples for details
    n_samples = _num_samples(X)

    # The splitter's self-reported split count is either checked against
    # the caller's expectation or adopted as the expectation.
    reported_n_splits = cv.get_n_splits(X, y, labels)
    if expected_n_iter is None:
        expected_n_iter = reported_n_splits
    else:
        assert_equal(reported_n_splits, expected_n_iter)

    # Walk every split, validating it and pooling its test indices.
    seen_test_indices = set()
    n_iterations = 0  # stays 0 if the splitter yields nothing
    splits = cv.split(X, y, labels)
    for n_iterations, (train_idx, test_idx) in enumerate(splits, start=1):
        check_valid_split(train_idx, test_idx, n_samples=n_samples)
        seen_test_indices.update(test_idx)

    # The number of splits actually produced must match the expectation.
    assert_equal(n_iterations, expected_n_iter)

    # Without a known sample count there is nothing to compare against.
    if n_samples is None:
        return

    # Every sample index must have appeared in at least one test fold.
    assert_equal(seen_test_indices, set(range(n_samples)))
评论列表
文章目录