def _load_cv_data(self, list_files):
"""Load training and cross-validation sets."""
# Split files for training and validation sets
val_files = np.array_split(list_files, self.n_folds)
train_files = np.setdiff1d(list_files, val_files[self.fold_idx])
# Load a npz file
print "Load training set:"
data_train, label_train = self._load_npz_list_files(train_files)
print " "
print "Load validation set:"
data_val, label_val = self._load_npz_list_files(val_files[self.fold_idx])
print " "
# Reshape the data to match the input of the model - conv2d
data_train = np.squeeze(data_train)
data_val = np.squeeze(data_val)
data_train = data_train[:, :, np.newaxis, np.newaxis]
data_val = data_val[:, :, np.newaxis, np.newaxis]
# Casting
data_train = data_train.astype(np.float32)
label_train = label_train.astype(np.int32)
data_val = data_val.astype(np.float32)
label_val = label_val.astype(np.int32)
return data_train, label_train, data_val, label_val
评论列表
文章目录