mnist_3d.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:huaat_ml_dl 作者: ieee820 项目源码 文件源码
def save_2d(label):
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    l_z,l_x,l_y = X_train.shape
    #cubes = np.ndarray([10,28,28],dtype=np.uint8)
    #new_1 = np.random(28,28)
    new_all = np.ones(784)
    new_all.resize(28,28)
    j = 1
    for i in range(0, l_z):
        #print X_train[i,:,:],y_train[i]
        #if j >= 10:
            #break;
        new = X_train[i,:,:]
        if y_train[i] == label :
            new_all = np.concatenate((new_all,new),axis=0)
            j = j +1

    #reshape and save
    new_all.resize(j,28,28)
    new_mini = new_all[1:,:,:]

    np.save('/home/yangjj/minist_npy/'+str(label),new_mini)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号