def check_cv(cv=3, y=None, classifier=False):
    """Dask-aware version of ``sklearn.model_selection.check_cv``.

    Behaves like the scikit-learn function, but also works when ``y`` is a
    dask collection.

    Parameters
    ----------
    cv : int, cross-validation generator, iterable or None, default 3
        Cross-validation specification. ``None`` is treated as ``3``.
    y : array-like or dask collection, optional
        Target values. Only inspected when ``cv`` is an integer and
        ``classifier`` is true.
    classifier : bool, default False
        Whether the estimator being validated is a classifier.

    Returns
    -------
    A cross-validator. When ``y`` is a dask collection and ``cv`` is an
    integer this is ``StratifiedKFold(cv)`` if ``classifier`` is true and
    the target type is ``'binary'`` or ``'multiclass'``, and ``KFold(cv)``
    otherwise. In every other case the result of
    ``sklearn.model_selection.check_cv`` is returned.
    """
    if cv is None:
        cv = 3

    # If ``cv`` is not an integer, the scikit-learn implementation doesn't
    # touch the ``y`` object, so passing on a dask object is fine.
    if not is_dask_collection(y) or not isinstance(cv, numbers.Integral):
        # ``classifier`` is passed by keyword: scikit-learn declares it
        # keyword-only in current releases, and the keyword form is also
        # accepted by older ones.
        return model_selection.check_cv(cv, y, classifier=classifier)

    if classifier:
        # ``y`` is a dask object. We need to compute the target type; this
        # triggers a computation on ``y``.
        target_type = delayed(type_of_target, pure=True)(y).compute()
        if target_type in ('binary', 'multiclass'):
            return StratifiedKFold(cv)

    # Non-classifiers, and classifiers whose target is neither binary nor
    # multiclass, get a plain K-fold split.
    return KFold(cv)
评论列表
文章目录