test_split.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_stratified_shuffle_split_overlap_train_test_bug():
    # See https://github.com/scikit-learn/scikit-learn/issues/6121 for
    # the original bug report
    y = [0, 1, 2, 3] * 3 + [4, 5] * 5
    X = np.ones_like(y)

    splits = StratifiedShuffleSplit(n_iter=1,
                                    test_size=0.5, random_state=0)

    train, test = next(iter(splits.split(X=X, y=y)))

    assert_array_equal(np.intersect1d(train, test), [])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号