def check_cv(cv=3, y=None, classifier=False):
    """Dask-aware counterpart of ``sklearn.model_selection.check_cv``.

    Behaves like the scikit-learn function, but also works when ``y`` is a
    dask collection.

    Parameters
    ----------
    cv : int, None, or any other value accepted by scikit-learn's ``check_cv``
        ``None`` is treated as ``3``.
    y : object, optional
        The target; may be a dask collection.
    classifier : bool
        Whether the estimator being cross-validated is a classifier.

    Returns
    -------
    A ``StratifiedKFold`` or ``KFold`` built from ``cv`` when ``y`` is a dask
    collection and ``cv`` is an integer; otherwise whatever
    ``model_selection.check_cv`` returns.
    """
    if cv is None:
        cv = 3

    # scikit-learn only touches ``y`` when ``cv`` is an integer, so only the
    # "dask ``y`` + integer ``cv``" combination needs special treatment here.
    if is_dask_collection(y) and isinstance(cv, numbers.Integral):
        stratify = False
        if classifier:
            # ``y`` is a dask object, so its target type has to be computed.
            target_type = delayed(type_of_target, pure=True)(y).compute()
            stratify = target_type in ('binary', 'multiclass')
        return StratifiedKFold(cv) if stratify else KFold(cv)

    # Everything else can be handed to scikit-learn unchanged.
    return model_selection.check_cv(cv, y, classifier)
# Example source code using check_cv() in Python
def _grid_search(self, train_X, train_y):
    """Run the inner grid search on the given training data.

    The inner cross-validation splits are generated from ``self.inner_cv``
    and passed, together with the parameter grid, estimator, scorer and fit
    parameters, to a ``MPIGridSearchCVMaster``. The value returned by its
    ``run(train_X, train_y)`` call is returned to the caller.
    """
    cv_source = self.inner_cv
    if not callable(cv_source):
        # Resolve the CV specification into an object exposing ``split``.
        # NOTE: an earlier form of this call also passed ``train_X`` to
        # ``_check_cv``; the current form only takes the target.
        cv_source = _check_cv(cv_source, train_y,
                              classifier=is_classifier(self.estimator))
    # NOTE: a callable ``inner_cv`` used to be invoked as
    # ``inner_cv(train_X, train_y)``; it is now expected to offer ``split``.
    splits = cv_source.split(train_X, train_y)

    return MPIGridSearchCVMaster(
        self.param_grid,
        splits,
        self.estimator,
        self.scorer_,
        self.fit_params,
    ).run(train_X, train_y)
def fit(self, X, y):
    """Fit the model to the training data.

    Validates ``X``/``y`` and the parameter grid, resolves the
    cross-validation strategy and the scorer (stored on ``self.scorer_``),
    then runs the master routine on rank 0 and the slave routine on every
    other rank.

    Returns
    -------
    self
    """
    X, y = check_X_y(
        X,
        y,
        force_all_finite=False,
        multi_output=self.multi_output,
    )
    _check_param_grid(self.param_grid)

    # NOTE: an earlier form of this call also passed ``X`` to ``_check_cv``.
    cv = _check_cv(self.cv, y, classifier=is_classifier(self.estimator))
    self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

    if comm_rank != 0:
        self._fit_slave()
    else:
        self._fit_master(X, y, cv)
    return self
def _set_cv(cv, X, y, classifier):
    """Resolve ``cv`` into a cross-validation object for the installed sklearn.

    Returns either a ``sklearn.cross_validation._PartitionIterator`` or a
    ``sklearn.model_selection.BaseCrossValidator``, depending on whether
    sklearn-0.17 or sklearn-0.18 is being used.

    Parameters
    ----------
    cv : int, ``_PartitionIterator`` or ``BaseCrossValidator``
        The CV object or int to check. If an int, it will be converted
        into the appropriate class of cross-validator.

    X : pd.DataFrame or np.ndarray, shape(n_samples, n_features)
        The dataframe or np.ndarray being fit in the grid search.

    y : np.ndarray, shape(n_samples,)
        The target being fit in the grid search.

    classifier : bool
        Whether the estimator being fit is a classifier.

    Returns
    -------
    ``_PartitionIterator`` or ``BaseCrossValidator``
    """
    if SK18:
        # This ``check_cv`` signature does not take ``X``.
        return check_cv(cv, y, classifier)
    return check_cv(cv, X, y, classifier)
def our_check_cv(cv, X, y, classifier):
    """Return the number of splits and the materialized list of splits.

    ``base_check_cv`` is called without ``X``; the splits are produced by the
    resulting object's ``split(X, y=y)``.
    """
    splitter = base_check_cv(cv, y, classifier)
    n_splits = splitter.n_splits
    folds = list(splitter.split(X, y=y))
    return n_splits, folds
def our_check_cv(cv, X, y, classifier):
    """Return the number of splits and the materialized list of splits.

    Here ``base_check_cv`` receives ``X`` as well, and the object it returns
    is sized with ``len`` and iterated directly to obtain the splits.
    """
    partitions = base_check_cv(cv, X, y, classifier)
    n_partitions = len(partitions)
    return n_partitions, list(partitions)
def _check_cv_non_float(self, y):
    """Resolve ``self.cv`` through ``check_cv`` for a non-float ``cv``.

    ``self.stratified`` is forwarded as the ``classifier`` argument and ``y``
    is passed on as given.
    """
    cv_spec = self.cv
    wants_stratified = self.stratified
    return check_cv(cv_spec, y=y, classifier=wants_stratified)
def check_cv(self, y):
    """Resolve which cross validation strategy is used.

    A float ``self.cv`` is handled by ``self._check_cv_float()``; anything
    else goes to ``self._check_cv_non_float``, which receives the target
    only when stratification is requested (``None`` otherwise).
    """
    if self.stratified:
        # Try to convert y to numpy for sklearn's check_cv; if the
        # conversion does not work, still try with y as it is.
        try:
            target = to_numpy(y)
        except (AttributeError, TypeError):
            target = y
    else:
        target = None

    if not self._is_float(self.cv):
        return self._check_cv_non_float(target)
    return self._check_cv_float()
def __call__(self, X, y):
    """Split ``X`` and ``y`` into one training part and one validation part.

    Only the first split produced by the resolved cross-validation object
    is used.

    Returns
    -------
    X_train, X_valid, y_train, y_valid
        The ``y`` entries are ``None`` when ``y`` is ``None``.

    Raises
    ------
    ValueError
        If stratification is requested but ``y`` is ``None`` or the resolved
        strategy is not stratified, or if ``X`` and ``y`` differ in length.
    """
    stratification_error = ValueError(
        "Stratified CV not possible with given y.")
    if self.stratified and y is None:
        raise stratification_error

    splitter = self.check_cv(y)
    if self.stratified and not self._is_stratified(splitter):
        raise stratification_error

    # pylint: disable=invalid-name
    n_samples = get_len(X)
    if y is not None and n_samples != get_len(y):
        raise ValueError("Cannot perform a CV split if X and y "
                         "have different lengths.")

    # The split runs over plain row indices; the target is only passed to
    # ``split`` when the strategy is stratified.
    split_args = [np.arange(n_samples)]
    if self._is_stratified(splitter):
        split_args.append(to_numpy(y))

    train_idx, valid_idx = next(iter(splitter.split(*split_args)))

    X_train = multi_indexing(X, train_idx)
    X_valid = multi_indexing(X, valid_idx)
    if y is None:
        y_train = None
        y_valid = None
    else:
        y_train = multi_indexing(y, train_idx)
        y_valid = multi_indexing(y, valid_idx)
    return X_train, X_valid, y_train, y_valid
def test_check_cv():
    """Check which splitter ``check_cv`` yields for different target types."""
    samples = np.ones(9)

    # Without a target and without ``classifier``, plain KFold is expected.
    # numpy.testing.assert_equal recursively compares lists of lists.
    splitter = check_cv(3, classifier=False)
    np.testing.assert_equal(list(KFold(3).split(samples)),
                            list(splitter.split(samples)))

    # Binary and multiclass targets must match StratifiedKFold.
    stratifiable_targets = (
        np.array([0, 1, 0, 1, 0, 0, 1, 1, 1]),  # binary
        np.array([0, 1, 0, 1, 2, 1, 2, 0, 2]),  # multiclass
    )
    for target in stratifiable_targets:
        splitter = check_cv(3, target, classifier=True)
        np.testing.assert_equal(
            list(StratifiedKFold(3).split(samples, target)),
            list(splitter.split(samples, target)))

    # Multilabel and multioutput targets must match plain KFold.
    samples = np.ones(5)
    unstratifiable_targets = (
        np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1],
                  [1, 1, 0, 1], [0, 0, 1, 0]]),  # multilabel
        np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]]),  # multioutput
    )
    for target in unstratifiable_targets:
        splitter = check_cv(3, target, classifier=True)
        np.testing.assert_equal(list(KFold(3).split(samples)),
                                list(splitter.split(samples)))

    # Check if the old style classes are wrapped to have a split method
    samples = np.ones(9)
    multiclass_target = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])
    new_style = check_cv(3, multiclass_target, classifier=True)

    # Warnings raised by the import are recorded instead of being shown.
    with warnings.catch_warnings(record=True):
        from sklearn.cross_validation import StratifiedKFold as OldSKF

    wrapped_old_style = check_cv(OldSKF(multiclass_target, n_folds=3))
    np.testing.assert_equal(list(new_style.split(samples, multiclass_target)),
                            list(wrapped_old_style.split()))

    assert_raises(ValueError, check_cv, cv="lolo")