def getFolds(labels, number_folds):
    """
    Build a shuffled, stratified k-fold splitter for the given labels.

    Parameters
    ----------
    labels : array-like of shape = [number_samples]
        The target values (class labels in classification).
    number_folds : int
        The number of folds for the k-fold cross-validation.

    Returns
    -------
    folds : StratifiedKFold
        The object produced by ``StratifiedKFold``; it provides the
        train/test indices of the split data.

    Notes
    -----
    ``shuffle=True`` is passed without any seed argument, so the split is
    presumably not reproducible between calls -- confirm if repeatability
    matters to the caller.
    """
    # NOTE(review): the ``y=`` / ``n_folds=`` keywords look like the legacy
    # ``sklearn.cross_validation.StratifiedKFold`` signature -- verify against
    # the import used by this file before upgrading scikit-learn.
    return StratifiedKFold(y=labels, n_folds=number_folds, shuffle=True)
# (Scraped web-page residue, not part of the program: "Comment list" / "Table of contents".)