classification.py source code

python

Project: decoding-brain-challenge-2016, Author: alexandrebarachant
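# Imports this method depends on (assumed; the module's import block is not
# part of the snippet). mean_covariance averages SPD matrices under the
# chosen metric (pyRiemann).
import numpy
from joblib import Parallel, delayed
from pyriemann.utils.mean import mean_covariance


def fit(self, X, y, sample_weight=None):
    """Fit (estimate) the centroids.

    Parameters
    ----------
    X : ndarray, shape (n_trials, n_channels, n_channels)
        ndarray of SPD matrices.
    y : ndarray, shape (n_trials,)
        Labels corresponding to each trial.
    sample_weight : None | ndarray, shape (n_trials,)
        The weight of each sample. If None, all samples are weighted
        equally.

    Returns
    -------
    self : MDM instance
        The MDM instance.
    """
    # Sorted unique labels; centroid k belongs to self.classes_[k].
    self.classes_ = numpy.unique(y)

    self.covmeans_ = []

    # Default to uniform weights when none are given.
    if sample_weight is None:
        sample_weight = numpy.ones(X.shape[0])

    if self.n_jobs == 1:
        # Sequential: one weighted mean covariance (centroid) per class.
        for label in self.classes_:
            self.covmeans_.append(
                mean_covariance(X[y == label], metric=self.metric_mean,
                                sample_weight=sample_weight[y == label]))
    else:
        # Parallel: compute the per-class centroids in joblib workers.
        self.covmeans_ = Parallel(n_jobs=self.n_jobs)(
            delayed(mean_covariance)(X[y == label], metric=self.metric_mean,
                                     sample_weight=sample_weight[y == label])
            for label in self.classes_)

    return self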
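The fit method only estimates one centroid per class; an MDM (minimum distance to mean) classifier then assigns each trial to the class whose centroid is nearest. The following is a hedged, numpy-only sketch of that idea, not the project's implementation: it swaps pyRiemann's mean_covariance for a weighted log-Euclidean mean and uses the log-Euclidean distance. The class LogEuclidMDM, the helpers _sym_fn, log_euclid_mean and random_spd, and the synthetic data are all hypothetical.

```python
import numpy as np


def _sym_fn(C, fn):
    """Apply a scalar function to a symmetric matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(C)
    # V @ diag(fn(w)) @ V.T, written as a column scaling of V.
    return (V * fn(w)) @ V.T


def log_euclid_mean(X, sample_weight):
    """Weighted log-Euclidean mean of SPD matrices X with shape (n, c, c)."""
    w = np.asarray(sample_weight, dtype=float)
    w = w / w.sum()
    logs = np.array([_sym_fn(C, np.log) for C in X])
    # Weighted average of matrix logarithms, mapped back with the matrix exp.
    return _sym_fn(np.tensordot(w, logs, axes=1), np.exp)


class LogEuclidMDM:
    """Hypothetical minimum-distance-to-mean classifier (log-Euclidean metric)."""

    def fit(self, X, y, sample_weight=None):
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        if sample_weight is None:
            sample_weight = np.ones(X.shape[0])
        sample_weight = np.asarray(sample_weight)
        # One weighted centroid per class, mirroring the loop in fit().
        self.covmeans_ = [log_euclid_mean(X[y == label], sample_weight[y == label])
                          for label in self.classes_]
        return self

    def predict(self, X):
        log_means = np.array([_sym_fn(M, np.log) for M in self.covmeans_])
        log_X = np.array([_sym_fn(C, np.log) for C in X])
        # Frobenius distance between matrix logarithms: shape (n_trials, n_classes).
        dist = np.linalg.norm(log_X[:, None] - log_means[None], axis=(2, 3))
        return self.classes_[dist.argmin(axis=1)]


# Usage on synthetic 3x3 SPD matrices; class 1 is class 0 scaled by 10.
rng = np.random.default_rng(0)


def random_spd(scale, n, c=3):
    A = rng.standard_normal((n, c, c))
    return scale * (A @ A.transpose(0, 2, 1) + c * np.eye(c))


X = np.concatenate([random_spd(1.0, 20), random_spd(10.0, 20)])
y = np.array([0] * 20 + [1] * 20)
clf = LogEuclidMDM().fit(X, y)
print(clf.predict(X[:3]), clf.predict(X[-3:]))
```

The log-Euclidean metric is used here only because it needs nothing beyond eigendecompositions; the original method leaves the choice of metric to self.metric_mean and the averaging to mean_covariance.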