def query(self, centroids):
    """Find the entity representations closest to each centroid.

    Args:
        centroids: Query points, one per row. Presumably a 2-D array-like of
            shape (n_centroids, n_features) -- TODO confirm against callers.

    Returns:
        A ``(distances, indices)`` tuple.

        If ``self.entity_neighbors`` is set, this is whatever its
        ``kneighbors(centroids)`` returns. That object is presumably a fitted
        ``sklearn.neighbors.NearestNeighbors``; verify this, since it also
        decides how many neighbours come back.

        Otherwise the result is brute force over *all* rows of
        ``self.entity_representations``. For each centroid, ``distances``
        holds every entity's distance under the metric
        ``self.entity_representation_distance``, sorted ascending.
        ``indices`` holds the matching entity row indices in the same order.
    """
    if self.entity_neighbors is not None:
        distances, indices = self.entity_neighbors.kneighbors(centroids)
        return distances, indices

    # Brute-force fallback: full (n_centroids, n_entities) distance matrix.
    pairwise_distances = scipy.spatial.distance.cdist(
        centroids, self.entity_representations,
        metric=self.entity_representation_distance)
    # One argsort gives each row's nearest-first ordering. Applying argsort
    # three times returns this same permutation, because argsort of a
    # permutation is its inverse.
    indices = np.argsort(pairwise_distances, axis=1)
    # Gather through the same permutation rather than sorting again, so
    # distances[i, j] is always the distance to entity indices[i, j].
    distances = np.take_along_axis(pairwise_distances, indices, axis=1)
    return distances, indices
# Scraped web-page residue, not code: "评论列表" ("Comment list"), "文章目录" ("Table of contents")