test_split.py source code (Python)

Project: Parallel-SGD, author: angadgill
```python
# _num_samples, assert_equal and check_valid_split are not defined in this
# excerpt; they are imported or defined elsewhere in test_split.py.
def check_cv_coverage(cv, X, y, labels, expected_n_iter=None):
    """Check that `cv` yields the expected number of valid splits and that
    every sample appears at least once in a test fold."""
    n_samples = _num_samples(X)

    # Check the number of splits reported by the splitter
    if expected_n_iter is not None:
        assert_equal(cv.get_n_splits(X, y, labels), expected_n_iter)
    else:
        expected_n_iter = cv.get_n_splits(X, y, labels)

    collected_test_samples = set()
    iterations = 0
    for train, test in cv.split(X, y, labels):
        check_valid_split(train, test, n_samples=n_samples)
        iterations += 1
        collected_test_samples.update(test)

    # Check that the splitter produced exactly the announced number of splits
    # and that the accumulated test samples cover the whole dataset
    assert_equal(iterations, expected_n_iter)
    if n_samples is not None:
        assert_equal(collected_test_samples, set(range(n_samples)))
```
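For readers without the surrounding test module, the same coverage check can be sketched with the standard library only. This is a minimal sketch, not the project's code: `SimpleKFold` and `check_cv_coverage_simple` are hypothetical names, plain `assert` replaces `assert_equal`, and the two set checks inside the loop are an assumed stand-in for `check_valid_split` (train and test indices are disjoint and together cover every sample).

```python
class SimpleKFold:
    """Hypothetical minimal K-fold splitter (not scikit-learn's KFold).

    Exposes the split/get_n_splits interface that check_cv_coverage expects;
    each sample index lands in exactly one contiguous test fold.
    """

    def __init__(self, n_splits=3):
        self.n_splits = n_splits

    def get_n_splits(self, X=None, y=None, labels=None):
        return self.n_splits

    def split(self, X, y=None, labels=None):
        n_samples = len(X)
        indices = list(range(n_samples))
        # Spread the remainder over the first folds so sizes differ by at most 1
        fold_sizes = [n_samples // self.n_splits] * self.n_splits
        for i in range(n_samples % self.n_splits):
            fold_sizes[i] += 1
        start = 0
        for size in fold_sizes:
            test = indices[start:start + size]
            train = indices[:start] + indices[start + size:]
            yield train, test
            start += size


def check_cv_coverage_simple(cv, X, y=None, labels=None, expected_n_iter=None):
    """Hypothetical stdlib-only variant of check_cv_coverage.

    Returns the number of splits so callers can inspect it.
    """
    n_samples = len(X)
    if expected_n_iter is not None:
        assert cv.get_n_splits(X, y, labels) == expected_n_iter
    else:
        expected_n_iter = cv.get_n_splits(X, y, labels)

    collected_test_samples = set()
    iterations = 0
    for train, test in cv.split(X, y, labels):
        # Assumed stand-in for check_valid_split: no overlap, full coverage
        assert not set(train) & set(test)
        assert set(train) | set(test) == set(range(n_samples))
        iterations += 1
        collected_test_samples.update(test)

    assert iterations == expected_n_iter
    # Every sample must have appeared in at least one test fold
    assert collected_test_samples == set(range(n_samples))
    return iterations


if __name__ == "__main__":
    print(check_cv_coverage_simple(SimpleKFold(3), list(range(10)), expected_n_iter=3))
```

With 10 samples and 3 folds, the test folds are `[0, 1, 2, 3]`, `[4, 5, 6]` and `[7, 8, 9]`, so every index is collected exactly once and the call returns 3; passing `expected_n_iter=4` instead raises `AssertionError`.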