def multi_label_train_test_split(y, test_size=0.2):
    """Create a train/test split that roughly preserves the multi-label distribution of `y`.

    `test_size` is approximated by a nearby fraction ``p/q`` using
    ``Fraction.limit_denominator`` with its default bound. ``y`` is split into
    ``q`` folds by ``equal_distribution_folds``. The first ``p`` folds form the
    test set, and the remaining ``q - p`` folds form the train set.

    Args:
        y: The multi-label outputs, passed through unchanged to
            ``equal_distribution_folds``.
        test_size: Fraction of the data to place in the test split. It must lie
            strictly inside the open interval (0, 1).

    Returns:
        A ``(train_indices, test_indices)`` tuple, each the result of
        ``np.concatenate`` over the corresponding folds.

    Raises:
        ValueError: If `test_size` is not strictly between 0 and 1 (NaN
            included). Also raised if `test_size` is so close to 0 or 1 that its
            rational approximation would leave the test or train split empty.
    """
    # `not 0 < x < 1` (instead of `x <= 0 or x >= 1`) also rejects NaN,
    # for which every ordered comparison evaluates to False.
    if not 0 < test_size < 1:
        raise ValueError("`test_size` should be between 0 and 1")

    # Approximate test_size by a simple nearby fraction p/q.
    # q becomes the total number of folds; p is the number of test folds.
    frac = Fraction(test_size).limit_denominator()
    test_folds, total_folds = frac.numerator, frac.denominator

    # Values extremely close to 0 or 1 round to 0/1 or 1/1.
    # Either would give one split zero folds, and np.concatenate on an empty
    # sequence fails with an unhelpful error, so reject them explicitly.
    if test_folds == 0 or test_folds == total_folds:
        raise ValueError(
            "`test_size`={!r} is too close to 0 or 1 to split into folds".format(test_size)
        )

    # `warning` replaces the deprecated `warn` alias; lazy %-args defer formatting.
    logger.warning(
        'Inferring test_size as %d/%d. Generating %d folds. '
        'The algorithm might fail if denominator is large.',
        test_folds, total_folds, total_folds,
    )
    # NOTE(review): assumes equal_distribution_folds returns a sliceable
    # sequence of index arrays; confirm against its definition.
    folds = equal_distribution_folds(y, folds=total_folds)
    test_indices = np.concatenate(folds[:test_folds])
    train_indices = np.concatenate(folds[test_folds:])
    return train_indices, test_indices
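# 评论列表 (Comment list) - web-page navigation residue, not code.
# 文章目录 (Table of contents) - web-page navigation residue, not code.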