def load_data(data_dir="MNIST_data/"):
    """Load MNIST and reshape the images into single-channel 28x28 arrays.

    Relies on a module-level ``input_data`` name providing
    ``read_data_sets``. NOTE(review): this is presumably
    ``tensorflow.examples.tutorials.mnist.input_data``, which exists only in
    TensorFlow 1.x. Confirm the pinned TensorFlow version.

    Labels are requested with ``one_hot=True``, so ``trainY`` and ``testY``
    are one-hot encoded.

    Args:
        data_dir: Directory passed as the first argument to
            ``read_data_sets``. Defaults to ``"MNIST_data/"``, the path this
            function previously hard-coded, so existing callers are unaffected.

    Returns:
        A tuple ``(trainX, trainY, testX, testY)``. ``trainX`` and ``testX``
        are reshaped to ``(num_examples, 28, 28, 1)``. ``trainY`` and
        ``testY`` are the label arrays exactly as returned by
        ``read_data_sets``.
    """
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    # -1 lets NumPy infer the example count. The trailing 1 adds an explicit
    # channel axis (channels-last), presumably for a convolutional model;
    # confirm against the caller.
    image_shape = (-1, 28, 28, 1)
    trainX = mnist.train.images.reshape(image_shape)
    testX = mnist.test.images.reshape(image_shape)

    trainY = mnist.train.labels
    testY = mnist.test.labels
    return trainX, trainY, testX, testY
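# NOTE: The two lines that followed this function in the original paste were
# leftover navigation text from the source web page, not code:
# "评论列表" (Comment list) and "文章目录" (Table of contents).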