knn.py source code

python

Project: cebl  Author: idfah
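import numpy as np

def probs(self, x):
    # Distances from each row of x to every training point. self.trainData holds
    # one array per class, so the columns of dists are grouped class by class.
    dists = np.hstack([self.distFunc(x, cls) for cls in self.trainData])
    # Column indices of the k nearest training points per row (unsorted).
    # np.argpartition with kth=self.k requires self.k < total training points.
    indices = np.argpartition(dists, self.k, axis=1)[:,:self.k]

    # Equivalent loop-based vote counting, left commented out in the source:
    #start = 0
    #votes = list()
    #for cls in self.trainData:
    #    end = start + cls.shape[0]
    #    votes.append(np.sum(np.logical_and(start <= indices, indices < end), axis=1))
    #    start = end

    # Class i owns the column range [starts[i], ends[i]); count how many of
    # the k neighbor indices fall inside each range.
    ends = np.cumsum([len(cls) for cls in self.trainData])
    starts = ends - np.array([len(cls) for cls in self.trainData])
    votes = [np.sum(np.logical_and(start <= indices, indices < end), axis=1)
             for start, end in zip(starts, ends)]
    votes = np.vstack(votes).T  # shape (n_observations, n_classes)

    # Commented-out alternatives: one-hot majority vote, or a softmax of votes.
    #probs = np.zeros((x.shape[0], self.nCls))
    #probs[np.arange(probs.shape[0]), np.argmax(votes, axis=1)] = 1.0
    ##probs = util.softmax(votes / float(self.k))
    # Class probability = fraction of the k neighbors belonging to that class.
    probs = votes / float(self.k)

    return probs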
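The method above depends on class state (`self.distFunc`, `self.trainData`, `self.k`) that this snippet does not show. Below is a minimal standalone sketch of the same vote-fraction idea, assuming Euclidean distance and training data stored as a list of per-class arrays. The names `euclidean_dists` and `knn_probs` are hypothetical and are not part of cebl. One deliberate deviation: it passes `kth=k - 1` to `np.argpartition`, which also selects the k smallest distances but additionally allows k to equal the total number of training points.

```python
import numpy as np


def euclidean_dists(x, cls):
    """Hypothetical stand-in for self.distFunc (assumed Euclidean).

    Returns pairwise distances of shape (len(x), len(cls)).
    """
    diff = x[:, None, :] - cls[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))


def knn_probs(x, train_data, k):
    """Fraction of the k nearest training points that fall in each class.

    train_data: list of arrays, one (n_i, n_features) array per class.
    Returns an array of shape (len(x), len(train_data)); each row sums to 1.
    """
    # Distances to all training points; columns are grouped class by class.
    dists = np.hstack([euclidean_dists(x, cls) for cls in train_data])

    # kth=k-1 moves the k smallest distances into the first k columns
    # (in no particular order); valid for any 1 <= k <= total points.
    indices = np.argpartition(dists, k - 1, axis=1)[:, :k]

    # Class i owns the column range [starts[i], ends[i]).
    sizes = np.array([len(cls) for cls in train_data])
    ends = np.cumsum(sizes)
    starts = ends - sizes

    # Count the neighbor indices that land in each class's range.
    votes = np.column_stack([
        np.sum((indices >= s) & (indices < e), axis=1)
        for s, e in zip(starts, ends)
    ])
    return votes / float(k)


# Toy 1-D data: class 0 clusters near 0, class 1 near 1.
train = [np.array([[0.0], [0.1], [0.2]]), np.array([[1.0], [1.1]])]
x = np.array([[0.05], [1.05]])
probs = knn_probs(x, train, k=3)
# Row 0: all 3 neighbors are class 0, giving [1, 0].
# Row 1: 2 neighbors from class 1 and 1 from class 0, giving [1/3, 2/3].
```

Mapping neighbor indices back to classes through cumulative class sizes avoids storing a per-point label array; an alternative would be to build such a label array once and count with `np.bincount`.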