def visual_data(self, samples_per_class=7):
    """Show a grid of randomly chosen training images, one column per class.

    For each of the 10 classes, up to ``samples_per_class`` training images
    whose label equals that class index are drawn without replacement and
    plotted in that class's column. The top image of each column is titled
    with the class name, and the figure is shown with ``plt.show()``.

    Args:
        samples_per_class: Maximum number of images drawn per class; also the
            number of grid rows (default 7). A class with fewer training
            samples shows all of them instead of raising, and a class with
            none is left blank.

    Reads:
        self.X_train: indexed per sample; each item is cast to uint8 and
            passed to ``plt.imshow`` (presumably an (H, W, 3) image with
            values in 0..255 -- confirm against the data loader).
        self.y_train: compared element-wise against class indices 0..9
            (presumably a NumPy array of integer labels -- a plain list would
            compare as a single bool and select nothing).
    """
    print("visualized some cifar-10 picture...")
    classes = ['plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
    num_classes = len(classes)
    for y, cl in enumerate(classes):
        idxs = np.flatnonzero(self.y_train == y)
        if idxs.size == 0:
            # No samples of this class: leave its column empty rather than
            # asking np.random.choice to sample from an empty population.
            continue
        # Sampling without replacement raises ValueError when asked for more
        # items than exist, so clamp the draw to the class size.
        count = min(samples_per_class, idxs.size)
        idxs = np.random.choice(idxs, count, replace=False)
        for i, idx in enumerate(idxs):
            # subplot positions are 1-based and fill row-major, so sample i
            # of class y lands in row i, column y.
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples_per_class, num_classes, plt_idx)
            plt.imshow(self.X_train[idx].astype('uint8'))
            plt.axis('off')
            if i == 0:
                plt.title(cl)
    plt.show()
    print("visualized data done...\n---------\n")
#
评论列表
文章目录