neighbors.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tslearn 作者: rtavenar 项目源码 文件源码
def predict(self, X):
        """Predict the class labels for the provided data

        Parameters
        ----------
        X : array-like, shape (n_ts, sz, d)
            Test samples.
        """
        X_ = to_time_series_dataset(X)
        neigh_dist, neigh_ind = self.kneighbors(X_)

        weights = _get_weights(neigh_dist, self.weights)

        if weights is None:
            mode, _ = stats.mode(self._fit_y[neigh_ind], axis=1)
        else:
            mode, _ = weighted_mode(self._fit_y[neigh_ind], weights, axis=1)

        return mode[:, 0]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号