test_birch.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:FreeDiscovery 作者: FreeDiscovery 项目源码 文件源码
def test_birch_hierarchy():
    """Check the consistency of the tree built by ``Birch``.

    Fits ``Birch`` with ``compute_sample_indices=True`` on a blob dataset
    and verifies that:

    * the subclusters found while walking the leaf linked list together
      hold every sample index exactly once;
    * the root has fewer subclusters than the leaf counter;
    * the subclusters of the root also hold every sample index exactly once;
    * a fixed sample belongs to the root subcluster whose centroid
      (recomputed from ``samples_id_``) is closest to it;
    * labels rebuilt from the leaf subclusters agree with ``brc.labels_``.
    """
    X, _ = make_blobs(random_state=40)
    brc = Birch(n_clusters=None, branching_factor=5,
                compute_sample_indices=True)
    brc.fit(X)

    n_samples = X.shape[0]
    all_sample_ids = np.arange(n_samples)

    # Walk the linked list of leaves and collect the sample ids stored in
    # their subclusters: leaf nodes must contain all the samples.
    # The counter starts at 1, so it ends one above the number of leaves
    # actually visited.
    n_leaves = 1
    ids_in_leaves = []
    leaf = brc.dummy_leaf_.next_leaf_
    while leaf:
        for subcluster in leaf.subclusters_:
            assert subcluster.n_samples_ == len(subcluster.samples_id_)
            ids_in_leaves.extend(subcluster.samples_id_)
        leaf = leaf.next_leaf_
        n_leaves += 1
    assert_array_equal(np.sort(ids_in_leaves), all_sample_ids)

    # Verify that the resulting hierarchical tree is deeper than 1 level
    # (i.e. subclusters of the root node are not tree leaves)
    root_subclusters = brc.root_.subclusters_
    assert len(root_subclusters) < n_leaves

    # The subclusters of the root node must also contain all the samples
    ids_in_root = []
    for subcluster in root_subclusters:
        ids_in_root.extend(subcluster.samples_id_)
        assert subcluster.n_samples_ == len(subcluster.samples_id_)
    assert_array_equal(np.sort(ids_in_root), all_sample_ids)

    # Take one fixed sample and make sure that the root subcluster reporting
    # it in samples_id_ is the one whose centroid is closest to the sample
    document_id = 45
    document = X[[document_id]]
    squared_distances = [
        ((document - X[subcluster.samples_id_, :].mean(axis=0))**2).sum()
        for subcluster in root_subclusters
    ]
    contains_document = [
        document_id in subcluster.samples_id_
        for subcluster in root_subclusters
    ]
    closest_subcluster = np.argmin(squared_distances)
    owning_subcluster = np.nonzero(contains_document)[0][0]
    assert closest_subcluster == owning_subcluster

    # Make sure that we can recompute labels from tree leaves: every leaf
    # subcluster gets its own consecutive label
    labels2 = np.zeros(n_samples, dtype=int)
    leaf_subclusters = (
        subcluster
        for leaf in brc._get_leaves()
        for subcluster in leaf.subclusters_
    )
    for cluster_id, subcluster in enumerate(leaf_subclusters):
        labels2[list(subcluster.samples_id_)] = cluster_id

    n_fitted_labels = np.unique(brc.labels_).shape
    n_rebuilt_labels = np.unique(labels2).shape
    assert n_fitted_labels == n_rebuilt_labels
    # The two methods yield approximately equal labels
    assert v_measure_score(brc.labels_, labels2) > 0.95
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号