def scatter_classes(x, classes, ax=None):
    """Scatter the data points, coloring each one by its class value.

    Args:
        x: 2-D array-like; column 0 is plotted on the horizontal axis and
            column 1 on the vertical axis.
        classes: Array-like of numeric class values, expected to hold one
            value per row of ``x``. Values are scaled linearly between
            their minimum and maximum and mapped through the ``jet``
            colormap.
        ax: Axes to draw on. A new figure and axes are created when omitted.

    Returns:
        The axes that were drawn on (or would have been, for empty input).
    """
    if ax is None:
        _fig, ax = plt.subplots()

    # Accept plain sequences as well as ndarrays; the column slicing below
    # needs real array indexing.
    x = np.asarray(x)
    classes = np.asarray(classes)
    if classes.size == 0:
        # np.min/np.max raise on empty input, and there is nothing to draw.
        return ax

    norm = matplotlib.colors.Normalize(
        vmin=np.min(classes), vmax=np.max(classes))
    mapper = matplotlib.cm.ScalarMappable(cmap=matplotlib.cm.jet, norm=norm)
    colors = mapper.to_rgba(classes)
    ax.scatter(x[:, 0], x[:, 1], color=colors)
    return ax
# Comment list
# Table of contents