def visualize(X, Y, classes, samples_per_class=10, out_path='img/data.png'):
    """Save a grid of randomly chosen sample images, one column per class.

    Column ``y`` of the grid shows images drawn without replacement from
    the entries of ``X`` whose label in ``Y`` equals ``y``; the first row
    of each column is titled with the class name.

    Args:
        X: Indexable collection of images; ``X[idx]`` is passed to
            ``plt.imshow`` with a gray colormap.
        Y: Labels aligned with ``X``. Assumed to be a NumPy array of
            integer class indices so that ``Y == y`` compares
            element-wise -- TODO confirm against the caller.
        classes: Sequence of class names; position ``y`` names label ``y``.
        samples_per_class: Number of rows in the grid, i.e. the maximum
            number of images drawn per class.
        out_path: Destination image file. Defaults to ``'img/data.png'``.

    A class with fewer than ``samples_per_class`` samples fills only as
    many rows as it has samples instead of raising ``ValueError``.
    """
    import os

    nb_classes = len(classes)
    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(Y == y)
        # Sampling without replacement cannot draw more than is available.
        n_draw = min(idxs.size, samples_per_class)
        if n_draw == 0:
            continue
        idxs = np.random.choice(idxs, n_draw, replace=False)
        for i, idx in enumerate(idxs):
            # Subplot indices are 1-based and run row by row, so row i,
            # column y maps to i * nb_classes + y + 1.
            plt_idx = i * nb_classes + y + 1
            plt.subplot(samples_per_class, nb_classes, plt_idx)
            plt.imshow(X[idx], cmap='gray')
            plt.axis('off')
            if i == 0:
                plt.title(cls)

    # The figure is written to disk rather than shown interactively.
    # Create the target directory first so savefig does not fail on a
    # fresh checkout.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)
    # Clear the current figure so later plots start from a blank canvas.
    plt.clf()
评论列表
文章目录