def fit(self, X, y):
    """Train xDAWN spatial filters.

    For every class, the class-averaged response ("prototyped response")
    is computed. A generalized eigendecomposition of its covariance
    against the covariance of the whole signal then yields spatial
    filters. The ``self.nfilter`` leading filters (largest eigenvalues)
    of each class are kept.

    Parameters
    ----------
    X : ndarray, shape (n_trials, n_channels, n_samples)
        ndarray of trials.
    y : array-like, shape (n_trials,) or (n_trials, 1)
        labels corresponding to each trial. A column vector (or a plain
        list) is flattened to 1-D before use.

    Returns
    -------
    self : Xdawn instance
        The Xdawn instance.

    Raises
    ------
    ValueError
        If a class listed in ``self.classes`` has no trial in ``y``.

    Notes
    -----
    Sets ``classes_``, plus ``filters_`` and ``patterns_``. The per-class
    blocks of ``filters_`` and ``patterns_`` are stacked row-wise. If
    ``self.estimator`` returns an (n_channels, n_channels) matrix and
    ``nfilter <= n_channels``, both arrays have shape
    (n_classes * nfilter, n_channels). Also sets ``evokeds_``, the
    filtered prototyped responses stacked row-wise, with shape
    (n_classes * nfilter, n_samples).
    """
    # Flatten labels so (n_trials,), (n_trials, 1) and list inputs all
    # give a 1-D boolean mask over the trial axis of X.
    y = numpy.asarray(y).ravel()
    Nt, Ne, Ns = X.shape
    self.classes_ = (numpy.unique(y) if self.classes is None else
                     self.classes)

    # Covariance of the whole signal: pool the samples of all trials into
    # one (n_channels, n_samples * n_trials) matrix.
    # FIXME : too many reshape operation
    tmp = X.transpose((1, 2, 0))
    Cx = self.estimator(tmp.reshape(Ne, Ns * Nt))

    self.evokeds_ = []
    self.filters_ = []
    self.patterns_ = []
    for c in self.classes_:
        mask = (y == c)
        if not numpy.any(mask):
            # Averaging zero trials would silently produce NaNs.
            raise ValueError("No trial found for class %r." % (c,))
        # Prototyped response of the class (average over its trials)
        P = numpy.mean(X[mask, :, :], axis=0)
        # Covariance matrix of the prototyped response
        C = self.estimator(P)
        # Spatial filters: generalized eigenvectors of (C, Cx), sorted by
        # decreasing eigenvalue and normalized to unit length.
        evals, evecs = eigh(C, Cx)
        evecs = evecs[:, numpy.argsort(evals)[::-1]]
        evecs /= numpy.linalg.norm(evecs, axis=0)
        V = evecs
        # Spatial patterns: pseudo-inverse of the transposed filters.
        A = numpy.linalg.pinv(V.T)
        # Keep the nfilter leading components and the reduced prototyped
        # response they produce.
        V_sel = V[:, :self.nfilter]
        self.filters_.append(V_sel.T)
        self.patterns_.append(A[:, :self.nfilter].T)
        self.evokeds_.append(numpy.dot(V_sel.T, P))
    self.evokeds_ = numpy.concatenate(self.evokeds_, axis=0)
    self.filters_ = numpy.concatenate(self.filters_, axis=0)
    self.patterns_ = numpy.concatenate(self.patterns_, axis=0)
    return self
spatialfilters.py file source code
python
Reads 23
Favorites 0
Likes 0
Comments 0
Comment list
Table of contents