def test_non_euclidean_kneighbors():
    """Check graph helpers agree with NearestNeighbors for non-Euclidean metrics.

    For the 'manhattan' and 'chebyshev' metrics, ``kneighbors_graph`` and
    ``radius_neighbors_graph`` computed directly from ``X`` must equal the
    connectivity graphs produced by a ``NearestNeighbors`` estimator fitted
    with the same metric. Passing an estimator fitted with 'manhattan'
    together with ``metric='euclidean'`` must raise ``ValueError``.
    """
    rng = np.random.RandomState(0)
    X = rng.rand(5, 5)

    # Find a reasonable radius: a mid-range value among the sorted pairwise
    # distances. np.sort returns a sorted copy rather than sorting in place,
    # so its result must be kept for the index below to be meaningful.
    dist_array = np.sort(pairwise_distances(X).flatten())
    radius = dist_array[15]

    # Test kneighbors_graph
    for metric in ['manhattan', 'chebyshev']:
        nbrs_graph = neighbors.kneighbors_graph(
            X, 3, metric=metric, mode='connectivity',
            include_self=True).toarray()
        nbrs1 = neighbors.NearestNeighbors(
            n_neighbors=3, metric=metric).fit(X)
        assert_array_equal(nbrs_graph, nbrs1.kneighbors_graph(X).toarray())

    # Test radius_neighbors_graph
    for metric in ['manhattan', 'chebyshev']:
        nbrs_graph = neighbors.radius_neighbors_graph(
            X, radius, metric=metric, mode='connectivity',
            include_self=True).toarray()
        nbrs1 = neighbors.NearestNeighbors(metric=metric, radius=radius).fit(X)
        # Use .toarray() rather than the sparse `.A` shorthand, which SciPy
        # deprecated and later removed.
        assert_array_equal(nbrs_graph,
                           nbrs1.radius_neighbors_graph(X).toarray())

    # Raise an error when a fitted estimator is supplied together with a
    # metric different from the one it was fitted with.
    X_nbrs = neighbors.NearestNeighbors(n_neighbors=3, metric='manhattan')
    X_nbrs.fit(X)
    assert_raises(ValueError, neighbors.kneighbors_graph, X_nbrs, 3,
                  metric='euclidean')
    X_nbrs = neighbors.NearestNeighbors(radius=radius, metric='manhattan')
    X_nbrs.fit(X)
    assert_raises(ValueError, neighbors.radius_neighbors_graph, X_nbrs,
                  radius, metric='euclidean')
评论列表
文章目录