def get_mnist(n_train=5000, n_test=500, pca=True, d=50, dtype=np.float32):
    """Load a centred, optionally PCA-reduced, subset of MNIST.

    The first ``n_train`` training images and the first ``n_test`` test
    images returned by ``mnist.load_data()`` are flattened to one row per
    image, cast to ``dtype``, divided by 255 and centred per feature.

    Args:
        n_train: Number of leading training samples to keep.
        n_test: Number of leading test samples to keep.
        pca: When true, project both splits with a ``PCA`` model fitted on
            the training subset only.
        d: Number of components passed to ``PCA`` (used only if ``pca``).
        dtype: Floating-point dtype the images are cast to before scaling.

    Returns:
        A tuple ``(X_train, y_train, X_test, y_test)``; the label arrays are
        sliced to the same leading ``n_train`` / ``n_test`` entries.
    """
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Images arrive as (samples, rows, cols); each one becomes a flat row.
    _, row, col = X_train.shape
    n_features = row * col

    # Select the subset *before* casting so only the rows that are actually
    # returned get converted to floating point; the values are identical to
    # converting everything first and slicing afterwards.
    X_train = X_train[:n_train].reshape(-1, n_features).astype(dtype)
    X_test = X_test[:n_test].reshape(-1, n_features).astype(dtype)

    # In-place division keeps the arrays in the requested dtype.
    X_train /= 255
    X_test /= 255

    # Centre each split on its own per-feature mean.
    # NOTE(review): the test split is centred with the *test* mean, not the
    # training mean, even though PCA below is fitted on training data only.
    # Kept as-is to preserve existing outputs; confirm this is intended.
    X_train = X_train - X_train.mean(axis=0)
    X_test = X_test - X_test.mean(axis=0)

    if pca:
        # Fit on the training subset, then apply the same projection to test.
        pcfit = PCA(n_components=d)
        X_train = pcfit.fit_transform(X_train)
        X_test = pcfit.transform(X_test)

    y_train = y_train[:n_train]
    y_test = y_test[:n_test]
    return X_train, y_train, X_test, y_test
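# Web-page residue (not part of the code): "Comment list"
# Web-page residue (not part of the code): "Table of contents"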