def plot_silhouettes(X, y):
    """Draw a silhouette plot with one band of horizontal bars per cluster.

    Per-sample silhouette coefficients are computed with the Euclidean
    metric, grouped by cluster label, sorted in ascending order inside each
    group, and stacked along the y axis. A dashed red vertical line marks
    the mean coefficient over all samples. The figure is displayed with
    ``plt.show()``; nothing is returned.

    Args:
        X: Sample data, passed unchanged to ``silhouette_samples``.
            (Presumably a samples-by-features array -- confirm with caller.)
        y: Cluster label of each sample. It is compared element-wise
            against every unique label, and the unique labels have 1 added
            to them for the tick text, so the labels must be numeric.
    """
    labels = np.unique(y)
    label_count = labels.shape[0]
    sample_scores = silhouette_samples(X, y, metric='euclidean')

    tick_positions = []
    band_start = 0  # y offset at which the current cluster's bars begin
    for idx, label in enumerate(labels):
        # Boolean-mask indexing yields a copy, so sorting it leaves
        # sample_scores untouched.
        band_scores = np.sort(sample_scores[y == label])
        band_end = band_start + len(band_scores)

        plt.barh(
            range(band_start, band_end),
            band_scores,
            height=1.0,
            edgecolor='none',
            color=cm.jet(idx / label_count),
        )

        # Put the cluster's tick at the vertical midpoint of its band.
        tick_positions.append((band_start + band_end) / 2)
        band_start = band_end

    plt.axvline(np.mean(sample_scores), color='red', linestyle='--')

    # Tick text is the label value shifted by one (0 -> 1, 1 -> 2, ...).
    plt.yticks(tick_positions, labels + 1)
    plt.ylabel('Cluster')
    plt.xlabel('Silhouette coefficient')
    plt.show()
chapter_11.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录