def test_callable_metric():
    """Check that a callable metric gives equal distances for 'auto' and 'brute'.

    Two ``NearestNeighbors`` estimators are fitted on the same random data
    with the same user-supplied metric function, one with
    ``algorithm='auto'`` and one with ``algorithm='brute'``. The distances
    returned by ``kneighbors`` must agree to the default precision of
    ``assert_array_almost_equal``.
    """
    def custom_metric(x1, x2):
        # Arbitrary callable of two points. Note it is not zero for
        # identical non-zero points; the test only needs both algorithms
        # to evaluate the same function.
        return np.sqrt(np.sum(x1 ** 2 + x2 ** 2))

    # Fixed seed keeps the 20 two-dimensional sample points reproducible.
    X = np.random.RandomState(42).rand(20, 2)

    # n_neighbors is passed by keyword so the call does not rely on
    # positional-argument order in the estimator's constructor.
    nbrs1 = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
                                       metric=custom_metric)
    nbrs2 = neighbors.NearestNeighbors(n_neighbors=3, algorithm='brute',
                                       metric=custom_metric)

    nbrs1.fit(X)
    nbrs2.fit(X)

    # Only the distances are compared; the neighbor indices are discarded.
    dist1, _ = nbrs1.kneighbors(X)
    dist2, _ = nbrs2.kneighbors(X)

    assert_array_almost_equal(dist1, dist2)
# Comment list
# Table of contents