def query(self, query, k=None, indices=None):
    """Return the positions of the fitted rows that best match ``query``.

    The query is vectorized with the parent class's ``transform`` and scored
    against the stored fitted matrix ``self._fit_X`` using ``linear_kernel``.
    The resulting score row is handed to ``argtopk`` to select the winners.

    Args:
        query: A single raw input. It is wrapped in a one-element list before
            being passed to ``super().transform``.
        k: Forwarded unchanged to ``argtopk``. What ``None`` means (e.g. "all")
            depends on ``argtopk``. NOTE(review): confirm against its definition.
        indices: Optional selector applied to the rows of ``self._fit_X``.
            When given, only those rows are scored.

    Returns:
        Whatever ``argtopk`` returns for the score row. These are column
        positions in the scored matrix. When ``indices`` is given they are
        therefore positions WITHIN THE SUBSET, not row numbers of the full
        ``self._fit_X``. Callers must map them back through ``indices``
        themselves.

    Raises:
        NotFittedError: If ``self._fit_X`` is ``None``.
    """
    if self._fit_X is None:
        raise NotFittedError

    query_vec = super().transform([query])

    # Score either every fitted row or only the requested subset.
    candidates = self._fit_X if indices is None else self._fit_X[indices]

    # Both matrices are expected to be l2-normalized, so this dot product
    # acts as cosine similarity. Row 0 holds the scores for the single query.
    scores = linear_kernel(query_vec, candidates)[0]
    return argtopk(scores, k)
# (scraped page residue) 评论列表 — "comment list"
# (scraped page residue) 文章目录 — "table of contents"