def test_k_fold_segment_split():
    """Test that SegmentCrossValidator folds are reproducible for a fixed seed.

    Builds a toy dataset of 12 segments x 20 samples (240 rows) indexed by a
    ('segment', 'start_sample') MultiIndex. Because the index is built
    segment-major, the first 120 rows (segments 0-5) carry class 0
    (interictal) and the last 120 rows (segments 6-11) carry class 1
    (preictal). Two validators constructed with identical arguments and the
    same ``random_state`` must yield the same number of folds, and each
    corresponding train/test pair must be identical.
    """
    n_segments = 12
    samples_per_segment = 20
    n_rows = n_segments * samples_per_segment
    half = n_rows // 2

    interictal_classes = np.zeros(half)
    preictal_classes = np.ones(half)
    classes = np.concatenate((interictal_classes, preictal_classes))
    segments = np.arange(n_segments)
    i = np.arange(n_rows)
    index = pd.MultiIndex.from_product(
        [segments, np.arange(samples_per_segment)],
        names=('segment', 'start_sample'),
    )
    dataframe = pd.DataFrame({'Preictal': classes, 'i': i}, index=index)

    # With 6 folds over 12 segments, each held-out fold is expected to hold
    # 2 segments (one per class). NOTE(review): that expectation is not
    # asserted here; this test only checks seeded reproducibility.
    cv1 = SegmentCrossValidator(dataframe, n_folds=6, shuffle=True, random_state=42)
    cv2 = SegmentCrossValidator(dataframe, n_folds=6, shuffle=True, random_state=42)

    # Materialize both so a differing fold count fails loudly instead of
    # being silently truncated by zip().
    folds1 = list(cv1)
    folds2 = list(cv2)
    assert len(folds1) == len(folds2)

    for (training_fold1, test_fold1), (training_fold2, test_fold2) in zip(folds1, folds2):
        # Compare validator 1 against validator 2 (previously training_fold1
        # was compared with itself, which always passed).
        assert np.array_equal(training_fold1, training_fold2)
        assert np.array_equal(test_fold1, test_fold2)
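# NOTE(review): removed stray web-page navigation text ("Comment list",
# "Table of contents") pasted here; as bare names it raised NameError on import.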