def test_stratified_shuffle_split_even():
    # Check that StratifiedShuffleSplit draws every sample index with an
    # equal chance: over many iterations, the number of times each index
    # lands in the train (resp. test) side should be plausible under a
    # binomial distribution whose success probability is the train
    # (resp. test) fraction.
    n_folds = 5
    n_iter = 1000
    test_size = 1. / n_folds
    train_size = 1. - test_size

    def assert_counts_are_ok(idx_counts, expected_p, n_draws):
        # Every per-index count must have a binomial probability mass above
        # a threshold that is scaled down by the number of draws.
        # ``n_draws`` is passed in explicitly rather than read from the
        # enclosing scope, so the helper does not depend on a variable that
        # is only bound later inside the loop below.
        threshold = 0.05 / n_draws
        binomial = stats.binom(n_draws, expected_p)
        for count in idx_counts:
            mass = binomial.pmf(count)
            assert_true(mass > threshold,
                        "An index is not drawn with chance corresponding "
                        "to even draws")

    for n_samples in (6, 22):
        # Balanced, alternating binary labels: [0, 1, 0, 1, ...]
        labels = np.array((n_samples // 2) * [0, 1])
        splits = StratifiedShuffleSplit(n_iter=n_iter,
                                        test_size=test_size,
                                        random_state=0)

        # How many times each sample index appeared on each side.
        train_counts = [0] * n_samples
        test_counts = [0] * n_samples
        n_splits = 0
        for train, test in splits.split(X=np.ones(n_samples), y=labels):
            n_splits += 1
            for counter, ids in [(train_counts, train), (test_counts, test)]:
                for idx in ids:
                    counter[idx] += 1
        assert_equal(n_splits, n_iter)

        n_train, n_test = _validate_shuffle_split(n_samples,
                                                  test_size=test_size,
                                                  train_size=train_size)

        # ``train`` / ``test`` still hold the last generated split here; the
        # assertion on ``n_splits`` above guarantees at least one was drawn.
        assert_equal(len(train), n_train)
        assert_equal(len(test), n_test)
        assert_equal(len(set(train).intersection(test)), 0)

        unique_labels = np.unique(labels)
        assert_equal(splits.test_size, test_size)
        assert_equal(n_train + n_test, len(labels))
        assert_equal(len(unique_labels), 2)
        ex_test_p = float(n_test) / n_samples
        ex_train_p = float(n_train) / n_samples

        assert_counts_are_ok(train_counts, ex_train_p, n_splits)
        assert_counts_are_ok(test_counts, ex_test_p, n_splits)
评论列表
文章目录