def plot(self, ax=None, holdon=False):
    """Scatter-plot every cluster in its own colour and mark the centroids.

    Args:
        ax: Axes object to draw on. When ``None``, a new figure and axes
            are created with ``plt.subplots()``.
        holdon: When false (the default) ``plt.show()`` is called after
            drawing. Pass ``True`` to skip it, e.g. to keep adding to the
            same axes before displaying.

    Reads ``self.X``, ``self.clusters``, ``self.centroids`` and ``self.K``.
    """
    sns.set(style="white")
    # Build the palette once; it does not change between clusters.
    palette = sns.color_palette("hls", self.K + 1)
    data = self.X

    if ax is None:
        _, ax = plt.subplots()

    for i, index in enumerate(self.clusters):
        # NOTE(review): assumes data[index] selects the rows of one cluster
        # and that samples are 2-D, so the transpose unpacks into (x, y)
        # -- confirm against the caller.
        point = np.array(data[index]).T
        # Pass the colour as a single-row 2-D sequence. A bare RGB tuple is
        # ambiguous for scatter(): with exactly 3 points in a cluster it is
        # interpreted as per-point values to colormap rather than one colour.
        ax.scatter(*point, c=[palette[i]])

    for point in self.centroids:
        ax.scatter(*point, marker='x', linewidths=10)

    if not holdon:
        plt.show()
评论列表
文章目录