estimator_checks.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def check_clustering(name, Alg):
    """Check basic fit / ``labels_`` / ``fit_predict`` behavior of a clusterer.

    The estimator is fitted on a small, shuffled and standardized blob
    dataset (50 samples), once from an array and once from a nested list.
    The resulting ``labels_`` must have one entry per sample and reach an
    adjusted Rand score above 0.4 against the generated blob labels.
    Unless the estimator is ``SpectralClustering``, a second run through
    ``fit_predict`` must reproduce ``labels_`` exactly.

    Parameters
    ----------
    name : str
        Name of the estimator; used to apply estimator-specific settings
        and to skip the determinism check for ``'SpectralClustering'``.
    Alg : callable
        Estimator class (or factory); called with no arguments to build
        the instance under test.

    Raises
    ------
    AssertionError
        Raised by the assertion helpers when any of the checks fails.
    """
    X, y = make_blobs(n_samples=50, random_state=1)
    X, y = shuffle(X, y, random_state=7)
    X = StandardScaler().fit_transform(X)
    n_samples = X.shape[0]
    # catch deprecation and neighbors warnings
    with warnings.catch_warnings(record=True):
        alg = Alg()
    set_testing_parameters(alg)
    if hasattr(alg, "n_clusters"):
        alg.set_params(n_clusters=3)
    set_random_state(alg)
    if name == 'AffinityPropagation':
        alg.set_params(preference=-100, max_iter=100)

    # fit
    alg.fit(X)
    # with lists
    alg.fit(X.tolist())

    assert_equal(alg.labels_.shape, (n_samples,))
    pred = alg.labels_
    assert_greater(adjusted_rand_score(pred, y), 0.4)
    # fit another time with ``fit_predict`` and compare results
    # NOTE: compare by value (``==``); ``is`` on a string literal only
    # tests object identity and depends on interpreter string interning.
    if name == 'SpectralClustering':
        # there is no way to make Spectral clustering deterministic :(
        return
    set_random_state(alg)
    with warnings.catch_warnings(record=True):
        pred2 = alg.fit_predict(X)
    assert_array_equal(pred, pred2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号