def load_data():
    """Load MNIST, normalise the images and one-hot encode the labels.

    The images returned by ``mnist.load_data()`` are cast to ``float32``,
    divided by 255, and flattened to one row per sample (784 columns for
    the standard 28x28 MNIST images). The labels are converted to 10-class
    categorical vectors with ``np_utils.to_categorical``.

    Returns:
        A list ``[X_train, X_test, y_train, y_test]`` (note the order:
        both image arrays first, then both label arrays).
    """
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('Loading data...')
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Cast before dividing so the in-place division operates on floats.
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255
    X_test /= 255

    y_train = np_utils.to_categorical(y_train, 10)
    y_test = np_utils.to_categorical(y_test, 10)

    # Flatten each image to a vector. The sample count comes from the data
    # itself rather than a hard-coded 60000/10000, and -1 lets NumPy infer
    # the per-sample width (784 for 28x28 images).
    X_train = np.reshape(X_train, (X_train.shape[0], -1))
    X_test = np.reshape(X_test, (X_test.shape[0], -1))

    print('Data loaded.')
    return [X_train, X_test, y_train, y_test]
评论列表
文章目录