def set_kfold(self, no_folds=10, fold_id=0, shuffle=True, random_state=125):
    """Split the samples into K folds and select one fold as the train/test split.

    Runs scikit-learn's ``KFold`` over the indices ``0 .. self.no_samples - 1``
    and keeps every fold in ``self.KFolds``, then selects fold ``fold_id``.

    Args:
        no_folds: Number of folds (passed to ``KFold`` as ``n_splits``).
        fold_id: Index into the list of folds; selects which split to use.
        shuffle: Whether ``KFold`` shuffles the indices before splitting.
            Defaults to True (the value previously hard-coded).
        random_state: Seed given to ``KFold`` for a reproducible shuffle.
            Defaults to 125 (the value previously hard-coded). It is only
            forwarded when ``shuffle`` is True, because ``KFold`` rejects a
            seed when shuffling is disabled.

    Side effects:
        Sets ``fold_id``, ``KFolds``, ``train_idx``, ``test_idx``,
        ``no_samples_train`` and ``no_samples_test`` on ``self``. Reports the
        split sizes through ``self.print_ext``. Adjusts ``train_batch_size`` and
        ``test_batch_size``: a value of 0 is replaced by the size of the
        corresponding split, and any value is capped at that size.

    Raises:
        IndexError: If ``fold_id`` does not index one of the generated folds.
    """
    inst = KFold(
        n_splits=no_folds,
        shuffle=shuffle,
        # KFold raises ValueError if a random_state is given without shuffling.
        random_state=random_state if shuffle else None,
    )
    self.fold_id = fold_id
    # Only the length of the array matters to KFold.split; the folds it yields
    # are index arrays into the full sample set.
    self.KFolds = list(inst.split(np.arange(self.no_samples)))
    n_folds = len(self.KFolds)
    if not -n_folds <= fold_id < n_folds:
        raise IndexError(
            f"fold_id {fold_id} is out of range for {n_folds} folds"
        )
    self.train_idx, self.test_idx = self.KFolds[fold_id]
    self.no_samples_train = self.train_idx.shape[0]
    self.no_samples_test = self.test_idx.shape[0]
    self.print_ext('Data ready. no_samples_train:', self.no_samples_train, 'no_samples_test:', self.no_samples_test)
    # A batch size of 0 means "use the whole split in one batch".
    if self.train_batch_size == 0:
        self.train_batch_size = self.no_samples_train
    if self.test_batch_size == 0:
        self.test_batch_size = self.no_samples_test
    # A batch can never be larger than the split it is drawn from.
    self.train_batch_size = min(self.train_batch_size, self.no_samples_train)
    self.test_batch_size = min(self.test_batch_size, self.no_samples_test)
# This function is cropped before batch
# Slice each sample to improve performance
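# Comment list (web-page navigation residue, originally "评论列表")
# Table of contents (web-page navigation residue, originally "文章目录")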