knnClassifier.py source code

python

Project: flexmatcher  Author: biggorilla-gh
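# Imports assumed by this excerpt (the original snippet does not show them):
# `np` is NumPy and `lev` is taken to be the python-Levenshtein module.
import numpy as np
import Levenshtein as lev

def find_knn(self, train_strings, train_labels, test_strings):
    """Find the 3 nearest neighbors of each item in test_strings among
    train_strings and report their labels as the prediction.

    Args:
        train_strings (ndarray): Numpy array with the strings in the training set
        train_labels (ndarray): Numpy array with the labels of train_strings
        test_strings (ndarray): Numpy array with the strings to predict labels for

    Returns:
        ndarray: Array of shape (len(test_strings), self.num_classes) in which
        each of the 3 neighbors adds 1/3 to the column of its label.
    """
    prediction = np.zeros((len(test_strings), self.num_classes))
    for i, a_str in enumerate(test_strings):
        # Edit distance from the current test string to every training string
        dists = np.array([lev.distance(a_str, b_str) for b_str in train_strings])
        # Indices of the 3 closest training strings
        top3 = dists.argsort()[:3]
        for ind in top3:
            # Each neighbor casts an equal vote for its own label
            prediction[i][self.column_index[train_labels[ind]]] += 1.0 / 3
    return prediction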
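
The voting scheme above can be tried without the surrounding class. The following is a minimal sketch under stated assumptions, not the flexmatcher implementation: `edit_distance` is a pure-Python stand-in for `lev.distance`, `knn_label_scores` and its `classes` and `k` parameters are hypothetical names introduced for illustration, and the sample column names are made up.

```python
def edit_distance(a, b):
    """Levenshtein distance via the classic two-row dynamic program."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution
        prev = curr
    return prev[-1]


def knn_label_scores(train_strings, train_labels, test_strings, classes, k=3):
    """Hypothetical standalone version of the vote: each of the k nearest
    training strings adds 1/k to the score of its label."""
    # Plays the role of self.column_index in the method above
    column_index = {label: i for i, label in enumerate(classes)}
    scores = []
    for a_str in test_strings:
        dists = [edit_distance(a_str, b_str) for b_str in train_strings]
        # Indices of the k smallest distances (sorted() is stable on ties)
        nearest = sorted(range(len(dists)), key=dists.__getitem__)[:k]
        row = [0.0] * len(classes)
        for ind in nearest:
            row[column_index[train_labels[ind]]] += 1.0 / k
        scores.append(row)
    return scores


# Made-up training data: column names labeled with the attribute they denote
train_strings = ["first name", "firstname", "fname",
                 "zip code", "zipcode", "postal code"]
train_labels = ["name", "name", "name", "zip", "zip", "zip"]

scores = knn_label_scores(train_strings, train_labels,
                          ["first_name", "zip"], classes=["name", "zip"])
for test_str, row in zip(["first_name", "zip"], scores):
    print(test_str, row)
```

With this data, "first_name" has three "name" neighbors and puts all of its weight on that column, while "zip" splits its vote 1/3 to "name" and 2/3 to "zip" because "fname" is among its three closest strings. Each row sums to 1 only when at least `k` training strings exist, which is also true of the fixed `1.0 / 3` weight in the method above.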