def mean_shift(location, location_callback, bandwidth=None):
    """Cluster the points of ``location`` with a mean shift algorithm.

    The result is sorted with the first value being the largest cluster
    (by number of member locations). Clusters of equal size keep the order
    in which MeanShift reported their centres.

    Args:
        location: Object providing ``_tuple_points()`` (one equal-length
            numeric sequence per point) and ``locations``, which is
            indexed with the row number of each point -- assumed to be
            index-aligned with ``_tuple_points()``; verify against callers.
        location_callback: Callable invoked once per non-empty cluster with
            the list of member locations; its return value is stored in
            the ``location`` field of the resulting tuple.

    Kwargs:
        bandwidth (float): If bandwidth is None (or otherwise falsy), a
            value is detected automatically from the input using
            estimate_bandwidth.

    Returns:
        A list of named tuples built by ``cluster_named_tuple()`` with the
        fields ``label``, ``centroid`` and ``location``, or None when there
        are no points or any coordinate is NaN or infinite.
    """
    pts = location._tuple_points()
    if not pts:
        return None

    X = np.array(pts).reshape((len(pts), len(pts[0])))
    # isfinite() is False for NaN as well as +/-inf, so one check covers both.
    if not np.all(np.isfinite(X)):
        return None
    # No imputation step is needed: X is known to be free of NaN here.
    X = X.astype(np.float32)

    if not bandwidth:
        bandwidth = estimate_bandwidth(X, quantile=0.3)
    # A falsy (e.g. zero) estimate is handed over as None so that MeanShift
    # falls back to its own bandwidth estimation.
    ms = MeanShift(bandwidth=bandwidth or None, bin_seeding=False).fit(X)

    # Group member locations by label in a single pass instead of rescanning
    # every label once per cluster centre.
    members_by_label = {}
    for j, label in enumerate(ms.labels_):
        members_by_label.setdefault(int(label), []).append(
            location.locations[j])

    sized_clusters = []
    for cluster_id, cluster_centre in enumerate(ms.cluster_centers_):
        members = members_by_label.get(cluster_id)
        if not members:
            continue
        cluster = cluster_named_tuple()(
            label=cluster_id,
            centroid=Point(cluster_centre),
            location=location_callback(members),
        )
        sized_clusters.append((len(members), cluster))

    # Enforce the documented ordering explicitly; list.sort is stable, so
    # equally sized clusters stay in MeanShift's order.
    sized_clusters.sort(key=lambda pair: pair[0], reverse=True)
    return [cluster for _, cluster in sized_clusters]
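# Comment list (web-page navigation residue, not part of the code)
# Table of contents (web-page navigation residue, not part of the code)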