lmnn.py source code

python

Project: pylmnn  Author: johny-c

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances


def _select_target_neighbors(self):
    """Find the target neighbors of each sample, which stay fixed during training.

    Returns
    -------
    ndarray
        An array of neighbor indices for each sample, with shape (n_samples, n_neighbors).

    """

    self.logger.info('Finding target neighbors...')
    target_neighbors = np.empty((self.X_.shape[0], self.n_neighbors_), dtype=int)
    for class_ in self.classes_:
        # Indices (into X_) of all samples that belong to the current class
        class_ind, = np.where(np.equal(self.y_, class_))
        # Pairwise squared Euclidean distances within this class only
        dist = euclidean_distances(self.X_[class_ind], squared=True)
        # Prevent a sample from being chosen as its own neighbor
        np.fill_diagonal(dist, np.inf)
        # Select the n_neighbors_ closest samples in each row (unordered)
        neigh_ind = np.argpartition(dist, self.n_neighbors_ - 1, axis=1)
        neigh_ind = neigh_ind[:, :self.n_neighbors_]
        # argpartition doesn't guarantee sorted order, so sort only the k selected neighbors
        row_ind = np.arange(len(class_ind))[:, None]
        neigh_ind = neigh_ind[row_ind, np.argsort(dist[row_ind, neigh_ind])]
        # Map positions within the class back to global sample indices
        target_neighbors[class_ind] = class_ind[neigh_ind]

    return target_neighbors
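
The method above depends on estimator state (`self.X_`, `self.y_`, `self.classes_`, `self.n_neighbors_`). Below is a minimal standalone sketch of the same per-class selection, assuming plain NumPy inputs: a 2-D array `X` of shape (n_samples, n_features), labels `y`, and `k` target neighbors. The function name `select_target_neighbors` is hypothetical. Squared distances are computed with NumPy broadcasting instead of scikit-learn's `euclidean_distances` so that only NumPy is needed, and the check for classes with too few samples is an added assumption, not part of the original method.

```python
import numpy as np


def select_target_neighbors(X, y, k):
    """Return an (n_samples, k) array of same-class neighbor indices, nearest first.

    Hypothetical standalone counterpart of _select_target_neighbors.
    X must be 2-D with shape (n_samples, n_features).
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    target_neighbors = np.empty((X.shape[0], k), dtype=int)
    for class_ in np.unique(y):
        class_ind, = np.where(y == class_)
        # Assumed guard (not in the original): each sample needs k other
        # samples of its own class, otherwise the inf diagonal gets selected
        if len(class_ind) <= k:
            raise ValueError(
                f"class {class_!r} has {len(class_ind)} samples; need more than k={k}"
            )
        Xc = X[class_ind]
        # Pairwise squared Euclidean distances via broadcasting
        dist = ((Xc[:, None, :] - Xc[None, :, :]) ** 2).sum(axis=-1)
        # A sample is never its own target neighbor
        np.fill_diagonal(dist, np.inf)
        # The k smallest distances per row, in arbitrary order
        neigh_ind = np.argpartition(dist, k - 1, axis=1)[:, :k]
        # Sort just those k candidates by distance
        row_ind = np.arange(len(class_ind))[:, None]
        order = np.argsort(dist[row_ind, neigh_ind], axis=1)
        neigh_ind = neigh_ind[row_ind, order]
        # Convert class-local positions to global sample indices
        target_neighbors[class_ind] = class_ind[neigh_ind]
    return target_neighbors


if __name__ == "__main__":
    X = np.array([[0.0], [1.0], [3.0], [10.0], [11.0], [13.0]])
    y = np.array([0, 0, 0, 1, 1, 1])
    print(select_target_neighbors(X, y, k=2).tolist())
    # → [[1, 2], [0, 2], [1, 0], [4, 5], [3, 5], [4, 3]]
```

On two well-separated 1-D classes, every sample's target neighbors are the other members of its own class, nearest first, and the indices refer to rows of the full `X` rather than positions inside the class. The `argpartition` step runs in linear time per row, and only the `k` selected candidates are then sorted, which is cheaper than a full `argsort` of every row when `k` is much smaller than the class size.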