def save_2d(label, out_dir='/home/yangjj/minist_npy/', dtype=np.float64, data=None):
    """Save all training images carrying ``label`` to ``<out_dir>/<label>.npy``.

    The matching images are stacked, in their original order, into one array of
    shape ``(k, H, W)``, where ``k`` is the number of training images whose label
    equals ``label``. ``k`` may be 0, giving an empty ``(0, H, W)`` array.

    Args:
        label: Label value to select; compared with ``==`` against ``y_train``.
        out_dir: Directory the ``.npy`` file is written into. It must already
            exist; it is not created here.
        dtype: dtype of the saved array. The default ``np.float64`` matches what
            the previous loop-based implementation produced (its ``np.ones``
            placeholder promoted the uint8 pixels to float64). Pass ``np.uint8``
            to keep the raw pixel type.
        data: Optional ``(X_train, y_train)`` pair to use instead of calling
            ``mnist.load_data()``, e.g. for another dataset or for testing.

    Returns:
        The array that was written to disk.
    """
    import os

    if data is None:
        (X_train, y_train), _ = mnist.load_data()
    else:
        X_train, y_train = data
    X_train = np.asarray(X_train)
    # ravel() tolerates label arrays shaped (N, 1) as well as (N,).
    y_train = np.ravel(np.asarray(y_train))

    # One vectorized boolean-mask selection. The old loop re-concatenated the
    # growing array on every match (quadratic copying) and then relied on an
    # in-place ndarray.resize, which raises if the array is referenced elsewhere.
    selected = X_train[y_train == label].astype(dtype)

    # np.save appends the ".npy" suffix itself.
    np.save(os.path.join(out_dir, str(label)), selected)
    return selected
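# Stray web-page navigation labels captured with this snippet (not code): "Comment list", "Table of contents".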