def validation_curve(estimator, epochs, y, param_name, param_range, cv=None):
    """Compute a validation curve for an estimator on epochs.

    The epochs data are restricted to the channels selected by
    ``_handle_picks``, flattened to one row per epoch, and handed to
    :func:`sklearn.model_selection.validation_curve`.

    Parameters
    ----------
    estimator : instance of GlobalAutoReject
        The estimator whose validation curve must be found. Any other
        estimator type is rejected.
    epochs : instance of mne.Epochs
        The epochs.
    y : array
        The labels.
    param_name : str
        Name of the parameter that will be varied.
    param_range : array
        The values of the parameter that will be evaluated.
    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation strategy. Passed through unchanged
        to scikit-learn.

    Returns
    -------
    train_scores : array
        The scores in the training set.
    test_scores : array
        The scores in the test set.

    Raises
    ------
    NotImplementedError
        If ``estimator`` is not a ``GlobalAutoReject`` instance.
    ValueError
        If ``epochs`` is not an instance of the type returned by
        ``_get_epochs_type()``.

    Notes
    -----
    ``estimator.n_channels`` and ``estimator.n_times`` are set as a side
    effect, using the shape of the picked data.
    """
    # Imported under an alias so that the scikit-learn function is not
    # confused with this wrapper, which carries the same name.
    from sklearn.model_selection import (
        validation_curve as _sk_validation_curve)

    if not isinstance(estimator, GlobalAutoReject):
        msg = 'No guarantee that it will work on this estimator.'
        raise NotImplementedError(msg)

    BaseEpochs = _get_epochs_type()
    if not isinstance(epochs, BaseEpochs):
        raise ValueError('Only accepts MNE epochs objects.')

    data_picks = _handle_picks(epochs.info, picks=None)
    X = epochs.get_data()[:, data_picks, :]
    n_epochs, n_channels, n_times = X.shape

    # The estimator receives flattened 2D data below, so it is told the
    # original channel/time dimensions here.
    estimator.n_channels = n_channels
    estimator.n_times = n_times

    # Forward the caller's ``param_name`` instead of a hard-coded value so
    # the documented parameter is actually honoured.
    train_scores, test_scores = _sk_validation_curve(
        estimator, X.reshape(n_epochs, -1), y=y,
        param_name=param_name, param_range=param_range,
        cv=cv, n_jobs=1, verbose=0)

    return train_scores, test_scores
# (Web-page residue, not part of the code: "Comment list" / "Article table of contents".)