def scatter(x, colors):
    """Draw a 2-D scatter plot of points coloured by integer label, then show it.

    Every label in 0-9 is annotated with its number, placed at the median
    position of the points carrying that label.

    Parameters
    ----------
    x : numpy.ndarray
        Point coordinates with exactly two columns; column 0 is plotted on
        the horizontal axis and column 1 on the vertical axis.
    colors : numpy.ndarray
        One label per row of ``x``. Values are cast to ``int`` and used as
        indices into a 258-entry palette, so they must lie in ``[0, 258)``
        (negative values would index from the end of the palette).

    Returns
    -------
    tuple
        ``(figure, axes, path_collection, text_labels)``, where
        ``text_labels`` is the list of the ten ``Text`` objects added for
        labels 0-9, in label order.
    """
    # Evenly spaced hues from seaborn's "hls" colour space.
    # NOTE(review): with 258 hues, neighbouring small labels such as 0-9 get
    # very similar colours; confirm 258 (rather than 10) is intended.
    palette = np.array(sea.color_palette("hls", 258))

    # Build the figure and the scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # Builtin ``int`` replaces the ``np.int`` alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24 (it raised AttributeError there).
    sc = ax.scatter(x[:, 0], x[:, 1], lw=0, s=40,
                    c=palette[colors.astype(int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    # NOTE(review): axis('tight') re-autoscales the view to the data, which
    # appears to override the fixed limits set just above — confirm intent.
    ax.axis('tight')

    # Label each digit at the median of its points.
    txts = []
    for i in range(10):
        # If no point carries label ``i`` the selection is empty, and
        # np.median returns NaN (with a RuntimeWarning), so this label's
        # text gets a NaN position.
        xtext, ytext = np.median(x[colors == i, :], axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        # White stroke behind the glyphs keeps the number legible over points.
        txt.set_path_effects([
            patheffects.Stroke(linewidth=5, foreground="w"),
            patheffects.Normal()])
        txts.append(txt)
    plt.show()
    return f, ax, sc, txts
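# 评论列表 ("comment list") and 文章目录 ("table of contents"): navigation
# text scraped from the source web page; not part of the code.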