load_data.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:Kiddo 作者: Subarno 项目源码 文件源码
def visualize(X, Y, classes, samples_per_class=10):
    nb_classes = len(classes)

    for y, cls in enumerate(classes):
        idxs = np.flatnonzero(Y == y)
        idxs = np.random.choice(idxs, samples_per_class, replace=False)

        for i, idx in enumerate(idxs):
            plt_idx = i * nb_classes + y + 1
            plt.subplot(samples_per_class, nb_classes, plt_idx)
            plt.imshow(X[idx], cmap='gray')
            plt.axis('off')
            if i == 0:
                plt.title(cls)
    #plt.show()
    plt.savefig('img/data.png')
    plt.clf()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号