def plot_classes(self, points, clusters, e, fig=0):
    """Scatter-plot the first sample's points coloured by cluster label and save it.

    Args:
        points: Batched point coordinates. Only ``points[0]`` is plotted, using
            its first two columns as x and y. NOTE(review): it is indexed with a
            NumPy index array below, so presumably a NumPy array (or CPU tensor)
            of shape (n_points, >=2) -- confirm against callers.
        clusters: Number of cluster labels to draw. Labels ``0 .. clusters-1``
            are plotted, each with its own colour from the rainbow colormap.
        e: Batched cluster assignments. ``e[0]`` is converted to NumPy via
            ``.data.cpu().numpy()``. Assumes ``e[0]`` holds one integer label
            per point (1-D) -- TODO confirm.
        fig: Matplotlib figure number to (re)use. The figure is cleared first.

    Side effects:
        Writes ``clustering_ex.png`` into the directory ``self.path``.
    """
    labels = e[0].data.cpu().numpy()
    sample_points = points[0]
    plt.figure(fig)
    plt.clf()
    colors = cm.rainbow(np.linspace(0, 1, clusters))
    for cl in range(clusters):
        ind = np.where(labels == cl)[0]
        pts = sample_points[ind]
        # Pass the single RGBA row via ``color=`` rather than ``c=``. A length-4
        # ``c`` is interpreted as per-point values to colormap whenever the
        # cluster has exactly 4 points, which would draw the wrong colours.
        plt.scatter(pts[:, 0], pts[:, 1], color=colors[cl])
    plt.title('clustering')
    # NOTE(review): the filename does not encode ``clusters``, so successive
    # calls overwrite the same file.
    path = os.path.join(self.path, 'clustering_ex.png')
    plt.savefig(path)
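# Stray web-page residue from the scraped source, kept as a comment so the
# module still imports: "评论列表" ("comment list"), "文章目录" ("table of contents").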