test.py source code

python

Project: tensorsne  Author: gokceneraslan
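# Imports inferred from usage (assumption): mnist.load_data() matches
# keras.datasets.mnist, and PCA matches sklearn.decomposition.PCA.
import numpy as np
from keras.datasets import mnist
from sklearn.decomposition import PCA


def get_mnist(n_train=5000, n_test=500, pca=True, d=50, dtype=np.float32):
    """Return centered (optionally PCA-reduced) MNIST subsets with their labels."""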
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Images arrive as (n_samples, row, col); the sample count is not needed here
    _, row, col = X_train.shape
    channel = 1

    # Flatten each image into a single feature vector
    X_train = X_train.reshape(-1, channel * row * col)
    X_test = X_test.reshape(-1, channel * row * col)
    X_train = X_train.astype(dtype)
    X_test = X_test.astype(dtype)
    # Scale pixel values from [0, 255] to [0, 1]
    X_train /= 255
    X_test /= 255

    # Keep the first n_train / n_test samples and center each subset
    # on its own per-feature mean
    X_train = X_train[:n_train] - X_train[:n_train].mean(axis=0)
    X_test = X_test[:n_test] - X_test[:n_test].mean(axis=0)

    if pca:
        # Fit PCA on the training subset only, then apply the same
        # projection to the test subset
        pcfit = PCA(n_components=d)

        X_train = pcfit.fit_transform(X_train)
        X_test = pcfit.transform(X_test)

    y_train = y_train[:n_train]
    y_test = y_test[:n_test]

    return X_train, y_train, X_test, y_test
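
The preprocessing steps in get_mnist (flatten, scale to [0, 1], center, project to d dimensions) can be tried without downloading MNIST. The following is only a rough sketch under stated assumptions: random uint8 arrays stand in for the real images, the helper names flatten_scale_center and pca_project are hypothetical, and the projection uses a plain NumPy SVD as a stand-in for sklearn's PCA, so it is not the tensorsne implementation itself.

```python
import numpy as np


def flatten_scale_center(images, n_keep, dtype=np.float32):
    """Flatten (n, row, col) images, scale to [0, 1], center the first n_keep rows."""
    flat = images.reshape(len(images), -1).astype(dtype)
    flat /= 255
    subset = flat[:n_keep]
    return subset - subset.mean(axis=0)


def pca_project(train, test, d):
    """Project both sets onto the top-d principal directions of train.

    Uses an SVD of the centered training data (stand-in for sklearn's PCA).
    """
    mean = train.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(train - mean, full_matrices=False)
    components = vt[:d]
    return (train - mean) @ components.T, (test - mean) @ components.T


# Synthetic stand-in for MNIST: 28x28 uint8 images (assumption, not real data)
rng = np.random.default_rng(0)
fake_train = rng.integers(0, 256, size=(200, 28, 28), dtype=np.uint8)
fake_test = rng.integers(0, 256, size=(40, 28, 28), dtype=np.uint8)

X_train = flatten_scale_center(fake_train, n_keep=100)
X_test = flatten_scale_center(fake_test, n_keep=20)
Z_train, Z_test = pca_project(X_train, X_test, d=50)

print(Z_train.shape, Z_test.shape)  # (100, 50) (20, 50)
```

As in get_mnist, the projection is learned from the training subset only and then reused for the test subset, so both end up in the same d-dimensional space.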