classification.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:decoding-brain-challenge-2016 作者: alexandrebarachant 项目源码 文件源码
def predict(self, covtest):
        """get the predictions.

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_channels)
            ndarray of SPD matrices.

        Returns
        -------
        pred : ndarray of int, shape (n_trials, 1)
            the prediction for each trials according to the closest centroid.
        """
        dist = self._predict_distances(covtest)
        neighbors_classes = self.classes_[numpy.argsort(dist)]
        out, _ = stats.mode(neighbors_classes[:, 0:self.n_neighbors], axis=1)
        return out.ravel()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号