def test_stratified_shuffle_split_overlap_train_test_bug():
    """Check that StratifiedShuffleSplit yields a disjoint, complete split.

    Regression test for
    https://github.com/scikit-learn/scikit-learn/issues/6121, the original
    report of train and test indices overlapping for class distributions
    like the one built below.
    """
    # Four classes with 3 samples each plus two classes with 5 samples each
    # (22 samples in total).
    y = [0, 1, 2, 3] * 3 + [4, 5] * 5
    X = np.ones_like(y)

    # Request a single split. The count is passed positionally because its
    # keyword name differs across scikit-learn versions (``n_iter`` in older
    # versions, ``n_splits`` in current ones).
    splitter = StratifiedShuffleSplit(1, test_size=0.5, random_state=0)
    train, test = next(iter(splitter.split(X=X, y=y)))

    # No sample may appear in both the train and the test set.
    assert_array_equal(np.intersect1d(train, test), [])

    # train_size is left at its default (the complement of test_size), so
    # train and test together must cover every sample.
    assert_array_equal(np.union1d(train, test), np.arange(len(y)))
评论列表
文章目录