def test_shuffle_kfold():
    """Check that shuffled KFold splitters reorder indices yet cover them all.

    Three splitters are built with 3 as their first argument: one with the
    default settings and two with ``shuffle=True`` under different
    ``random_state`` seeds.  Their splits of a 300-element array are walked
    in lockstep, and two properties are asserted:

    * for every pair of training sets produced in the same iteration, the
      size of their intersection differs from the size of the first
      splitter's training set, i.e. no two training sets coincide;
    * the test indices yielded by the first shuffled splitter, accumulated
      over all iterations, mark every one of the 300 positions.
    """
    kf = KFold(3)
    kf2 = KFold(3, shuffle=True, random_state=0)
    kf3 = KFold(3, shuffle=True, random_state=1)

    X = np.ones(300)

    # One flag per sample; set to 1 once kf2 has used it as a test index.
    all_folds = np.zeros(300)
    # Only kf2's test indices are consumed below, so the test folds of the
    # other two splitters are discarded.
    for (tr1, _), (tr2, te2), (tr3, _) in zip(
            kf.split(X), kf2.split(X), kf3.split(X)):
        for tr_a, tr_b in combinations((tr1, tr2, tr3), 2):
            # Assert that there is no complete overlap
            assert_not_equal(len(np.intersect1d(tr_a, tr_b)), len(tr1))

        # Set all test indices in successive iterations of kf2 to 1
        all_folds[te2] = 1

    # Check that all indices are returned in the different test folds
    assert_equal(all_folds.sum(), 300)
# Web-page navigation residue ("Comment list", "Table of contents");
# not part of the test module.