def find_knn(self, train_strings, train_labels, test_strings, k=3, distance=None):
    """Score each test string by a k-nearest-neighbour vote over the training set.

    For every item in ``test_strings`` the distance to each item in
    ``train_strings`` is computed. The ``k`` closest training items are
    selected, and each casts an equal-weight vote for its label.

    Args:
        train_strings (ndarray): Numpy array with the strings in the training set.
        train_labels (ndarray): Numpy array with the labels of ``train_strings``,
            in the same order.
        test_strings (ndarray): Numpy array with the strings to predict for.
        k (int): Number of nearest neighbours that vote. Defaults to 3.
        distance (callable, optional): ``distance(a, b)`` returning a number used
            to compare two strings. Defaults to ``lev.distance``.

    Returns:
        ndarray: Float array of shape ``(len(test_strings), self.num_classes)``.
        Row ``i`` holds, for each label (column chosen via
        ``self.column_index``), the fraction of the selected neighbours of
        ``test_strings[i]`` carrying that label. Each row sums to 1 (up to
        float rounding) when the training set is non-empty; all rows are zero
        when it is empty.

    Raises:
        ValueError: If ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if distance is None:
        distance = lev.distance

    prediction = np.zeros((len(test_strings), self.num_classes))
    if len(train_strings) == 0:
        # No neighbours exist to vote, so every row stays at zero.
        return prediction

    for i, a_str in enumerate(test_strings):
        dists = np.array([distance(a_str, b_str) for b_str in train_strings])
        # A stable sort breaks distance ties in favour of the earlier training
        # item, independent of numpy's size-dependent default sort algorithm.
        nearest = np.argsort(dists, kind="stable")[:k]
        # Divide by the number of neighbours actually found, so a row still
        # sums to 1 when there are fewer than k training items.
        weight = 1.0 / len(nearest)
        for ind in nearest:
            prediction[i, self.column_index[train_labels[ind]]] += weight
    return prediction
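# NOTE(review): two lines of scraped web-page navigation text ("comment list",
# "table of contents") stood here; they were not code and have been removed.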