def _check_cv(cv=3, y=None, classifier=False, **kwargs):
"""Input checker utility for building a cross-validator.
Parameters
----------
cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- An object to be used as a cross-validation generator.
- An iterable yielding train/test splits.
For integer/None inputs, if classifier is True and ``y`` is either
binary or multiclass, :class:`StratifiedKFold` is used. In all other
cases, :class:`KFold` is used.
Refer :ref:`User Guide <cross_validation>` for the various
cross-validation strategies that can be used here.
y : array-like, optional
The target variable for supervised learning problems.
classifier : boolean, optional, default False
Whether the task is a classification task, in which case
stratified KFold will be used.
kwargs : dict
Other parameters for StratifiedShuffleSplit or ShuffleSplit.
Returns
-------
checked_cv : a cross-validator instance.
The return value is a cross-validator which generates the train/test
splits via the ``split`` method.
"""
if cv is None:
cv = kwargs.pop('n_splits', 0) or 10
if isinstance(cv, numbers.Integral):
if (classifier and (y is not None) and
(type_of_target(y) in ('binary', 'multiclass'))):
return StratifiedShuffleSplit(cv, **kwargs)
else:
return ShuffleSplit(cv, **kwargs)
if not hasattr(cv, 'split') or isinstance(cv, str):
if not isinstance(cv, Iterable) or isinstance(cv, str):
raise ValueError("Expected cv as an integer, cross-validation "
"object (from sklearn.model_selection) "
"or an iterable. Got %s." % cv)
return _CVIterableWrapper(cv)
return cv # New style cv objects are passed without any modification
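# NOTE(review): two stray lines of web-page navigation text ("comment list",
# "table of contents") stood here after the function; they are not code.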