def test_stratified_shuffle_split_init():
    """Check that StratifiedShuffleSplit rejects invalid size settings.

    Every case below expects a ``ValueError``, raised either directly by
    the constructor or when the first split is requested via ``next``.
    """

    def expect_first_split_error(data, labels, *args, **kwargs):
        # Build the splitter first (outside assert_raises), then require
        # that asking for the first train/test split raises ValueError.
        splitter = StratifiedShuffleSplit(*args, **kwargs)
        assert_raises(ValueError, next, splitter.split(data, labels))

    # Unbalanced labels: class 0 occurs exactly once.
    samples = np.arange(7)
    targets = np.asarray([0, 1, 1, 1, 2, 2, 2])

    # A class with only one sample must be rejected.
    expect_first_split_error(samples, targets, 3, 0.2)
    # A test set smaller than the number of classes must be rejected.
    expect_first_split_error(samples, targets, 3, 2)
    # A train set smaller than the number of classes must be rejected.
    expect_first_split_error(samples, targets, 3, 3, 2)

    # Balanced labels: three samples for each of the three classes.
    samples = np.arange(9)
    targets = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2])

    # Not enough samples to satisfy the requested sizes. The first case
    # is expected to fail already at construction time.
    assert_raises(ValueError, StratifiedShuffleSplit, 3, 0.5, 0.6)
    expect_first_split_error(samples, targets, 3, 8, 0.6)
    expect_first_split_error(samples, targets, 3, 0.6, 8)

    # Train size or test size too small.
    expect_first_split_error(samples, targets, train_size=2)
    expect_first_split_error(samples, targets, test_size=2)
评论列表
文章目录