def __call__(self, X, y):
    """Split ``X`` (and ``y``, if given) into a training and a validation part.

    Only the first ``(train, valid)`` index pair yielded by the splitter
    returned from ``self.check_cv(y)`` is used.

    Parameters
    ----------
    X
      Input data. Its length is taken with ``get_len`` and subsets are
      selected with ``multi_indexing``.
    y
      Target data or ``None``. Required when a stratified split is
      performed; when given, it must have the same length as ``X``.

    Returns
    -------
    X_train, X_valid, y_train, y_valid
      The training and validation subsets. ``y_train`` and ``y_valid``
      are ``None`` when ``y`` is ``None``.

    Raises
    ------
    ValueError
      If ``self.stratified`` is set but ``y`` is ``None`` or the splitter
      from ``check_cv`` is not stratified; if the splitter is stratified
      but ``y`` is ``None``; or if ``X`` and ``y`` differ in length.
    """
    bad_y_error = ValueError("Stratified CV not possible with given y.")
    if (y is None) and self.stratified:
        raise bad_y_error

    cv = self.check_cv(y)
    # Evaluate once: the result drives both validation and whether y is
    # passed on to ``cv.split``.
    is_stratified = self._is_stratified(cv)
    if self.stratified and not is_stratified:
        raise bad_y_error
    # NOTE(review): check_cv presumably can hand back a stratified splitter
    # even when self.stratified is False (e.g. a user-supplied one). A
    # stratified split needs y, so fail with the clear error here instead of
    # passing None to to_numpy below.
    if is_stratified and y is None:
        raise bad_y_error

    # pylint: disable=invalid-name
    len_X = get_len(X)
    if y is not None:
        len_y = get_len(y)
        if len_X != len_y:
            raise ValueError("Cannot perform a CV split if X and y "
                             "have different lengths.")

    # Split on positional indices; the actual data is selected afterwards
    # with multi_indexing.
    args = (np.arange(len_X),)
    if is_stratified:
        args = args + (to_numpy(y),)

    # Only the first fold produced by the splitter is used.
    idx_train, idx_valid = next(iter(cv.split(*args)))
    X_train = multi_indexing(X, idx_train)
    X_valid = multi_indexing(X, idx_valid)
    y_train = None if y is None else multi_indexing(y, idx_train)
    y_valid = None if y is None else multi_indexing(y, idx_valid)
    return X_train, X_valid, y_train, y_valid
# Comment list
# Table of contents