test_neighbors.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_k_and_radius_neighbors_duplicates():
    """Test neighbor queries when the query contains duplicates of fit data.

    Two situations are checked:

    * An explicit query equal to the training points, for every entry of
      ``ALGORITHMS``: duplicates get no special treatment, so each query
      point's single nearest neighbor is the identical training point at
      distance 0, and radius queries return both training points.
    * No explicit query (``kneighbors()`` / ``kneighbors_graph()`` called
      without ``X``) on three identical samples: the expected indices are
      ``[[1], [0], [1]]``, i.e. no sample is reported as its own neighbor,
      and the zero distances are still stored explicitly in the graph.

    Sparse results are densified with ``.toarray()`` rather than the ``.A``
    shorthand, because ``.toarray()`` is available on both SciPy sparse
    matrices and the newer sparse array types.
    """
    for algorithm in ALGORITHMS:
        nn = neighbors.NearestNeighbors(n_neighbors=1, algorithm=algorithm)
        nn.fit([[0], [1]])

        # Do not do anything special to duplicates.
        kng = nn.kneighbors_graph([[0], [1]], mode='distance')
        assert_array_equal(kng.toarray(), np.zeros((2, 2)))
        # The zero distances must be stored entries, not implicit zeros.
        assert_array_equal(kng.data, [0., 0.])
        assert_array_equal(kng.indices, [0, 1])

        dist, ind = nn.radius_neighbors([[0], [1]], radius=1.5)
        check_object_arrays(dist, [[0, 1], [1, 0]])
        check_object_arrays(ind, [[0, 1], [0, 1]])

        # Default (connectivity) mode: both points lie within the radius of
        # each query, so the graph is all ones.
        rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5)
        assert_array_equal(rng.toarray(), np.ones((2, 2)))

        rng = nn.radius_neighbors_graph([[0], [1]], radius=1.5,
                                        mode='distance')
        assert_array_equal(rng.toarray(), [[0, 1], [1, 0]])
        assert_array_equal(rng.indices, [0, 1, 0, 1])
        assert_array_equal(rng.data, [0, 1, 1, 0])

    # The checks below build their estimator without an ``algorithm``
    # argument, so they do not depend on the loop variable and only need to
    # run once.

    # Mask the first duplicates when n_duplicates > n_neighbors.
    X = np.ones((3, 1))
    nn = neighbors.NearestNeighbors(n_neighbors=1)
    nn.fit(X)
    dist, ind = nn.kneighbors()
    assert_array_equal(dist, np.zeros((3, 1)))
    assert_array_equal(ind, [[1], [0], [1]])

    # Test that zeros are explicitly marked in kneighbors_graph.
    kng = nn.kneighbors_graph(mode='distance')
    assert_array_equal(kng.toarray(), np.zeros((3, 3)))
    assert_array_equal(kng.data, np.zeros(3))
    assert_array_equal(kng.indices, [1, 0, 1])
    assert_array_equal(
        nn.kneighbors_graph().toarray(),
        np.array([[0., 1., 0.], [1., 0., 0.], [0., 1., 0.]]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号