def get_mnist(image_size):
    """Load MNIST, shuffle it deterministically, and split it into train/test.

    Pixel values are rescaled from [0, 255] to [-1, 1] and every sample is
    reshaped to ``(1, image_size, image_size)`` (a single channel).

    Args:
        image_size: Side length of the square images. Each sample in the
            dataset must contain exactly ``image_size ** 2`` values.

    Returns:
        ``(X_train, X_test, Y_train, Y_test)``. The first 60000 shuffled
        samples form the training split and the remaining samples form the
        test split. ``X_*`` are float32 arrays of shape
        ``(N, 1, image_size, image_size)``; ``Y_*`` are the matching labels.

    Raises:
        ValueError: if the data cannot be reshaped into
            ``image_size x image_size`` samples.

    NOTE(review): the split is taken *after* shuffling, so the test split is a
    random subset of all samples rather than the official MNIST test set
    (assuming the source keeps the official ordering). Confirm this is
    intended.

    NOTE(review): ``fetch_mldata`` was deprecated in scikit-learn 0.20 and
    removed in 0.22 (mldata.org is offline). ``fetch_openml('mnist_784')`` is
    the usual replacement, but it returns string labels and may return a
    DataFrame. Verify before switching.
    """
    n_train = 60000  # conventional MNIST training-set size
    mnist = fetch_mldata('MNIST original')
    num_samples = mnist.data.shape[0]

    # A private RandomState gives a reproducible ordering without resetting
    # the caller's global NumPy RNG. It produces the same permutation that
    # np.random.seed(1234) followed by np.random.permutation would.
    rng = np.random.RandomState(1234)
    p = rng.permutation(num_samples)

    expected_size = num_samples * image_size * image_size
    if mnist.data.size != expected_size:
        raise ValueError(
            "cannot reshape %d values into %d samples of %dx%d images"
            % (mnist.data.size, num_samples, image_size, image_size)
        )

    X = mnist.data[p].reshape((num_samples, 1, image_size, image_size))
    Y = mnist.target[p]
    # Map [0, 255] -> [-1, 1].
    X = X.astype(np.float32) / (255.0 / 2) - 1.0

    X_train = X[:n_train]
    X_test = X[n_train:]
    Y_train = Y[:n_train]
    Y_test = Y[n_train:]
    return X_train, X_test, Y_train, Y_test
评论列表
文章目录