def mean_shift(fig):
    """Cluster ``X_iris`` with MeanShift and draw the result on a 3-D subplot.

    Reads the module-level names ``X_iris`` (the samples), ``geo`` (base
    subplot position code) and ``cluster`` (presumably ``sklearn.cluster``;
    confirm against the file's imports).

    Args:
        fig: matplotlib Figure that receives a new 3-D subplot at position
            ``geo + 4``.

    Returns:
        The per-sample cluster labels (``labels_`` of the fitted model).
    """
    # X_iris and geo are only read here, so no ``global`` declaration is needed.
    ax = fig.add_subplot(geo + 4, projection='3d', title='mean_shift')

    bandwidth = cluster.estimate_bandwidth(X_iris, quantile=0.2, n_samples=50)
    # Named ``model`` so it does not shadow this function's own name.
    model = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    model.fit(X_iris)
    res = model.labels_

    # Cycle through a fixed palette when there are more clusters than colours.
    palette = 'bgrcmyk'
    colors = [palette[label % len(palette)] for label in res]

    # Plot only the first three features. Use one scatter call for all points
    # rather than one per point: this is far cheaper, and depth shading is
    # computed across the whole cloud.
    xs, ys, zs = zip(*(row[:3] for row in X_iris))
    ax.scatter(xs, ys, zs, c=colors, marker='o')

    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    return res
评论列表
文章目录