def plot_representations(X, y, title):
    """Plot each sample's label as text at its min-max scaled 2-D position.

    Every column of ``X`` is rescaled to ``[0, 1]``. The first two columns
    are then used as the x/y position of a bold text marker showing
    ``str(y[i])``, coloured with ``plt.cm.Set1(y[i] / 10.)``.

    Args:
        X: 2-D array-like of shape (n_samples, n_features) with at least two
            columns; only columns 0 and 1 are drawn.
        y: Labels indexable by sample position. They must support division
            by a float, because the colour is derived from ``y[i] / 10.``.
        title: Title for the plot, or ``None`` for no title.

    Returns:
        The newly created matplotlib figure.
    """
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    span = x_max - x_min
    # A constant column has zero range; dividing by it would produce NaN
    # coordinates. Dividing by 1 instead maps such a column to 0.
    span = np.where(span == 0, 1, span)
    X = (X - x_min) / span

    f = plt.figure(figsize=(15, 10.8), dpi=300)
    for i, point in enumerate(X):
        plt.text(
            point[0],
            point[1],
            str(y[i]),
            color=plt.cm.Set1(y[i] / 10.),
            fontdict={'weight': 'bold', 'size': 9},
        )

    # Positions are normalised, so tick values carry no meaning; hide them.
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
    return f
tools.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录