def kfold(self, k=5, stratify=False, shuffle=True, seed=33):
    """Iterate over K-fold cross-validation splits of the training data.

    The splitter is built from scikit-learn's ``StratifiedKFold`` (when
    ``stratify`` is true) or ``KFold`` (otherwise) and applied to
    ``self.X_train`` / ``self.y_train``.

    Parameters
    ----------
    k : int, default 5
        Number of folds, forwarded as ``n_splits``.
    stratify : bool, default False
        Use ``StratifiedKFold`` instead of ``KFold``.
    shuffle : bool, default True
        Forwarded to the splitter as ``shuffle``.
    seed : int, default 33
        Forwarded as ``random_state`` when ``shuffle`` is true. It is ignored
        when ``shuffle`` is false, because the split is then deterministic.

    Yields
    ------
    X_train, y_train, X_test, y_test, train_index, test_index
        The training and held-out portions of ``self.X_train`` and
        ``self.y_train`` for one fold, followed by the index arrays produced
        by the splitter for that fold.
    """
    # scikit-learn rejects a non-None random_state combined with
    # shuffle=False, so only forward the seed when it can have an effect.
    random_state = seed if shuffle else None
    splitter_cls = StratifiedKFold if stratify else KFold
    kf = splitter_cls(n_splits=k, random_state=random_state, shuffle=shuffle)

    for train_index, test_index in kf.split(self.X_train, self.y_train):
        # NOTE(review): `idx` is defined outside this block; presumably it
        # selects the given rows of X regardless of container type - confirm.
        X_train, y_train = idx(self.X_train, train_index), self.y_train[train_index]
        X_test, y_test = idx(self.X_train, test_index), self.y_train[test_index]
        yield X_train, y_train, X_test, y_test, train_index, test_index
评论列表
文章目录