def __call__(self, X, y):
    """Build cross-validation splits for (X, y) and run ``self.cross_validate``.

    The splitting procedure is chosen from the instance configuration:

    * If ``self.n_folds`` is not None, cross validation is done on
      stratified folds (``StratifiedKFold`` without shuffling).
    * Otherwise, if ``self.n_class_samples`` is not None, ``self.n_tests``
      train/test splits are produced by ``split_dataset``, using only
      ``n_class_samples`` training samples per class and, if
      ``self.n_test_samples`` is not None, only ``n_test_samples``
      cross-validation samples per class.
    * If both are None, ``self.folds`` is left as it is.

    Args:
        X: data matrix; each datapoint is assumed to be a column of X.
            It is only forwarded to ``self.cross_validate``.
        y: class label of each datapoint.

    Side effects:
        Sets ``self.folds``. In the per-class branch, a non-list
        ``self.n_class_samples`` and a non-None ``self.n_test_samples``
        are replaced in place by integer arrays with one entry per class
        (``len(set(y))`` entries).
    """
    if self.n_folds is not None:
        # Stratified folds: class proportions are kept in each split.
        self.folds = StratifiedKFold(y, n_folds=self.n_folds,
                                     shuffle=False, random_state=None)
    elif self.n_class_samples is not None:
        n_classes = len(set(y))
        # Broadcast scalar per-class counts to one entry per class. This
        # does not depend on the loop below, so do it once up front.
        # NOTE(review): a plain list is passed to split_dataset unchanged,
        # while n_test_samples is always broadcast -- confirm this
        # asymmetry is intended.
        # NOTE(review): these attributes are overwritten in place, so a
        # later call with a different number of classes will fail to
        # broadcast -- consider keeping the user-supplied values separate.
        if not isinstance(self.n_class_samples, list):
            self.n_class_samples = (
                np.ones(n_classes) * self.n_class_samples).astype(int)
        if self.n_test_samples is not None:
            self.n_test_samples = (
                np.ones(n_classes) * self.n_test_samples).astype(int)

        self.folds = []
        for _ in range(self.n_tests):
            data_idx = split_dataset(self.n_class_samples,
                                     self.n_test_samples, y)
            # split_dataset's first two items are the train / test indices.
            train_idx, test_idx = data_idx[0], data_idx[1]
            self.folds.append((train_idx, test_idx))
    self.cross_validate(X, y)
评论列表
文章目录