def fit(self, X, y, sample_weight=None):
    """Fit (estimate) the class centroids.

    One centroid is computed per distinct label found in ``y`` by calling
    ``mean_covariance`` on the trials of that class, using the metric
    stored in ``self.metric_mean``. The centroids are stored in
    ``self.covmeans_`` in the same order as ``self.classes_``.

    Parameters
    ----------
    X : ndarray, shape (n_trials, n_channels, n_channels)
        ndarray of SPD matrices.
    y : array-like, shape (n_trials,)
        Labels corresponding to each trial. Used as a boolean mask over
        the first axis of ``X``, so it must be one-dimensional.
    sample_weight : None | array-like, shape (n_trials,)
        The weights of each sample. If None, each sample is treated with
        equal weights.

    Returns
    -------
    self : MDM instance
        The MDM instance.
    """
    # Coerce to ndarray: with a plain list, `y == label` collapses to a
    # single bool instead of an element-wise mask, which would silently
    # select no trials at all.
    y = numpy.asarray(y)
    self.classes_ = numpy.unique(y)

    if sample_weight is None:
        sample_weight = numpy.ones(X.shape[0])
    else:
        # A list cannot be indexed with a boolean mask.
        sample_weight = numpy.asarray(sample_weight)

    # One boolean mask per class, computed once and shared by the trial
    # selection and the weight selection.
    masks = [y == label for label in self.classes_]

    if self.n_jobs == 1:
        # Serial path: plain calls, no joblib dispatch involved.
        self.covmeans_ = [
            mean_covariance(X[mask], metric=self.metric_mean,
                            sample_weight=sample_weight[mask])
            for mask in masks
        ]
    else:
        self.covmeans_ = Parallel(n_jobs=self.n_jobs)(
            delayed(mean_covariance)(X[mask], metric=self.metric_mean,
                                     sample_weight=sample_weight[mask])
            for mask in masks)
    return self
classification.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录