def get_cv_generator(training_data, do_segment_split=True, random_state=None, n_folds=10):
    """
    Returns a cross validation generator.

    :param training_data: The training data to create the folds from. When
                          do_segment_split is False it is indexed with the key
                          'Preictal', and that entry is handed to StratifiedKFold
                          (presumably the class labels -- confirm against caller).
    :param do_segment_split: If True, the folds will be generated based on the segment names
                             by wrapping StratifiedKFold in dataset.SegmentCrossValidator.
                             If False, a plain StratifiedKFold is built directly.
    :param random_state: A constant to use as a random seed. Forwarded unchanged
                         as the random_state keyword argument.
    :param n_folds: The number of folds to generate. Defaults to 10, which was
                    previously the hard-coded value.
    :return: A generator which can be used by the grid search to generate cross validation folds.
    """
    # Both branches receive the same fold configuration.
    # NOTE(review): in the legacy sklearn.cross_validation API, random_state
    # appears to matter only when shuffle=True, which is not set here --
    # confirm whether shuffling is intended.
    k_fold_kwargs = dict(n_folds=n_folds, random_state=random_state)

    if do_segment_split:
        # The fold class itself (not an instance) is passed in, together with
        # the keyword arguments; SegmentCrossValidator is defined elsewhere.
        return dataset.SegmentCrossValidator(training_data, cross_validation.StratifiedKFold, **k_fold_kwargs)

    return sklearn.cross_validation.StratifiedKFold(training_data['Preictal'], **k_fold_kwargs)
seizure_modeling.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录