k_means.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:dask-ml 作者: dask 项目源码 文件源码
def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means||',
                         verbose=False, x_squared_norms=None,
                         random_state=None, tol=1e-4,
                         precompute_distances=True,
                         oversampling_factor=2,
                         init_max_iter=None):
    centers = k_init(X, n_clusters, init=init,
                     oversampling_factor=oversampling_factor,
                     random_state=random_state, max_iter=init_max_iter)
    dt = X.dtype
    P = X.shape[1]
    for i in range(max_iter):
        t0 = tic()
        labels, distances = pairwise_distances_argmin_min(
            X, centers, metric='euclidean', metric_kwargs={"squared": True}
        )

        labels = labels.astype(np.int32)
        # distances is always float64, but we need it to match X.dtype
        # for centers_dense, but remain float64 for inertia
        r = da.atop(_centers_dense, 'ij',
                    X, 'ij',
                    labels, 'i',
                    n_clusters, None,
                    distances.astype(X.dtype), 'i',
                    adjust_chunks={"i": n_clusters, "j": P},
                    dtype=X.dtype)
        new_centers = da.from_delayed(
            sum(r.to_delayed().flatten()),
            (n_clusters, P),
            X.dtype
        )
        counts = da.bincount(labels, minlength=n_clusters)
        # Require at least one per bucket, to avoid division by 0.
        counts = da.maximum(counts, 1)
        new_centers = new_centers / counts[:, None]
        new_centers, = compute(new_centers)

        # Convergence check
        shift = squared_norm(centers - new_centers)
        t1 = tic()
        logger.info("Lloyd loop %2d. Shift: %0.4f [%.2f s]", i, shift, t1 - t0)
        if shift < tol:
            break
        centers = new_centers

    if shift > 1e-7:
        labels, distances = pairwise_distances_argmin_min(X, centers)
        labels = labels.astype(np.int32)

    inertia = distances.sum()
    centers = centers.astype(dt)

    return labels, inertia, centers, i + 1
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号