def _preprocess_images(images, binarize, threshold):
    """Binarize or rescale a batch of MNIST images and append a trailing axis.

    Args:
        images: Raw pixel array from ``mnist.load_data()``. Presumably
            uint8 intensities in [0, 255] -- confirm against the loader.
        binarize: If True, pixels ``>= threshold`` become 1 and all others
            become -1 (integer array). Otherwise pixels are cast to float32
            and rescaled with ``(x - 127.5) / 127.5``, which maps [0, 255]
            onto [-1, 1].
        threshold: Cut-off intensity; only used when ``binarize`` is True.

    Returns:
        The processed array with one extra trailing axis of length 1.
    """
    if binarize:
        processed = np.where(images >= threshold, 1, -1)
    else:
        processed = (images.astype(np.float32) - 127.5) / 127.5
    # Trailing singleton axis -- presumably a channel axis for the model
    # consuming this data; confirm against the caller.
    return np.expand_dims(processed, axis=-1)


def get_mnist_data(binarize=False, *, threshold=10):
    """Load MNIST and put it in the format expected for training.

    Images are either binarized to {-1, 1} (integer dtype) or rescaled to
    float32 values in [-1, 1]. Either way they gain a trailing axis of
    length 1. Labels are one-hot encoded over 10 classes.

    Args:
        binarize: If True, binarize pixels instead of rescaling them.
        threshold: Pixel intensity at or above which a pixel maps to 1 when
            ``binarize`` is True. Defaults to 10, the value previously
            hard-coded. It is ignored when ``binarize`` is False.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``. ``X_*`` are the
        processed images with an extra trailing axis. ``y_*`` are float
        one-hot arrays of width 10.

    Raises:
        IndexError: If any label falls outside ``0..9``, because labels
            index rows of a 10x10 identity matrix.
    """
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = _preprocess_images(X_train, binarize, threshold)
    X_test = _preprocess_images(X_test, binarize, threshold)
    # Indexing rows of the identity matrix is a vectorized one-hot encoding.
    one_hot = np.eye(10)
    return (X_train, one_hot[y_train]), (X_test, one_hot[y_test])
评论列表
文章目录