Python fetch_mldata() usage examples
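Note: mldata.org has been offline for years, and fetch_mldata was deprecated in scikit-learn 0.20 and removed in 0.22, so the snippets below no longer run against current scikit-learn. Most of them port directly to fetch_openml; a minimal migration sketch, assuming a recent scikit-learn where fetch_openml supports the as_frame keyword, and using the OpenML mirror of MNIST named 'mnist_784':

# Migration sketch: fetch_openml replaces the removed fetch_mldata.
# Assumes a recent scikit-learn and network access to openml.org.
import numpy as np
from sklearn.datasets import fetch_openml

# 'mnist_784' is the OpenML counterpart of the old 'MNIST original':
# 70000 flattened 28x28 images in the standard train/test order.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X = mnist.data.astype(np.float32)   # shape (70000, 784)
y = mnist.target.astype(np.int64)   # OpenML targets arrive as strings

X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]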

mnist.py (project: mlens, author: flennerhag)
def load_data(dtype=np.float32, order='F'):
    """Load the data, then cache and memmap the train/test split"""
    ######################################################################
    # Load dataset
    safe_print("Loading dataset...")
    data = fetch_mldata('MNIST original')
    X = check_array(data['data'], dtype=dtype, order=order)
    y = data["target"]

    # Normalize features
    X = X / 255

    # Create train-test split (as [Joachims, 2006])
    safe_print("Creating train-test split...")
    n_train = 60000
    X_train = X[:n_train]
    y_train = y[:n_train]
    X_test = X[n_train:]
    y_test = y[n_train:]

    return X_train, X_test, y_train, y_test
base.py (project: impyute, author: eltonlaw)
def mnist(missingness="mcar", thr=0.2):
    """ Loads corrupted MNIST

    Parameters
    ----------
    missingness: ('mcar', 'mar', 'mnar')
        Type of missingness you want in your dataset
    thr: float between [0, 1]
        Fraction of values made missing in the generated data

    Returns
    -------
    dict
        {"X": corrupted data, "Y": target values}
    """
    from sklearn.datasets import fetch_mldata
    dataset = fetch_mldata('MNIST original')
    corruptor = Corruptor(dataset.data, thr=thr)
    data = getattr(corruptor, missingness)()
    return {"X": data, "Y": dataset.target}
CO2_1d_regression.py (project: SCFGP, author: MaxInGaussian)
def load_co2_data(prop=0.8):
    from sklearn.datasets import fetch_mldata
    data = fetch_mldata('mauna-loa-atmospheric-co2').data
    X = data[:, [1]]
    y = data[:, 0]
    y = y[:, None]
    X = X.astype(np.float64)
    ntrain = y.shape[0]
    train_inds = npr.choice(range(ntrain), int(prop*ntrain), replace=False)
    valid_inds = np.setdiff1d(range(ntrain), train_inds)
    X_train, y_train = X[train_inds].copy(), y[train_inds].copy()
    X_valid, y_valid = X[valid_inds].copy(), y[valid_inds].copy()
    return X_train, y_train, X_valid, y_valid

############################ Training & Visualizing ############################
knn_train_mnist.py (project: cv_ml, author: techfort)
def get_trained_knn():
    print("Training k-NN classifier for MNIST dataset")
    mnist = fetch_mldata("MNIST original")
    KNN = cv2.ml.KNearest_create()
    traindata, trainlabels = [], []

    # populate labels
    for k in mnist.target:
        trainlabels.append(k)

    # populate images
    for d in mnist.data:
        traindata.append(np.array(d, dtype=np.float32))

    # train the model
    KNN.train(np.array(traindata), cv2.ml.ROW_SAMPLE, np.array(trainlabels, dtype=np.int32))
    # KNN.save("hwdigits.xml")
    return KNN
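A sketch of classifying a query image with the trained model via OpenCV's findNearest; the zero-filled query row is a placeholder for a real flattened 28x28 image:

# Sketch: classify one query image with the k-NN model trained above.
# `query` is a placeholder; a real caller would pass a flattened image.
import numpy as np
knn = get_trained_knn()
query = np.zeros((1, 784), dtype=np.float32)
ret, results, neighbours, dist = knn.findNearest(query, k=5)
print(int(results[0][0]))  # predicted digit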
dcgan.py (project: mxnet_tk1, author: starimpact)
def get_mnist():
    mnist = fetch_mldata('MNIST original')
    np.random.seed(1234) # set seed for deterministic ordering
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p]
    X = X.reshape((70000, 28, 28))

    X = np.asarray([cv2.resize(x, (64,64)) for x in X])

    X = X.astype(np.float32)/(255.0/2) - 1.0
    X = X.reshape((70000, 1, 64, 64))
    X = np.tile(X, (1, 3, 1, 1))
    X_train = X[:60000]
    X_test = X[60000:]

    return X_train, X_test
bench_mnist.py (project: Parallel-SGD, author: angadgill)
def load_data(dtype=np.float32, order='F'):
    """Load the data, then cache and memmap the train/test split"""
    ######################################################################
    # Load dataset
    print("Loading dataset...")
    data = fetch_mldata('MNIST original')
    X = check_array(data['data'], dtype=dtype, order=order)
    y = data["target"]

    # Normalize features
    X = X / 255

    # Create train-test split (as [Joachims, 2006])
    print("Creating train-test split...")
    n_train = 60000
    X_train = X[:n_train]
    y_train = y[:n_train]
    X_test = X[n_train:]
    y_test = y[n_train:]

    return X_train, X_test, y_train, y_test
test_mldata.py (project: Parallel-SGD, author: angadgill)
def test_download():
    """Test that fetch_mldata is able to download and cache a data set."""

    _urlopen_ref = datasets.mldata.urlopen
    datasets.mldata.urlopen = mock_mldata_urlopen({
        'mock': {
            'label': sp.ones((150,)),
            'data': sp.ones((150, 4)),
        },
    })
    try:
        mock = fetch_mldata('mock', data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "target", "data"]:
            assert_in(n, mock)

        assert_equal(mock.target.shape, (150,))
        assert_equal(mock.data.shape, (150, 4))

        assert_raises(datasets.mldata.HTTPError,
                      fetch_mldata, 'not_existing_name')
    finally:
        datasets.mldata.urlopen = _urlopen_ref
test_mldata.py (project: Parallel-SGD, author: angadgill)
def test_fetch_one_column():
    _urlopen_ref = datasets.mldata.urlopen
    try:
        dataname = 'onecol'
        # create fake data set in cache
        x = sp.arange(6).reshape(2, 3)
        datasets.mldata.urlopen = mock_mldata_urlopen({dataname: {'x': x}})

        dset = fetch_mldata(dataname, data_home=tmpdir)
        for n in ["COL_NAMES", "DESCR", "data"]:
            assert_in(n, dset)
        assert_not_in("target", dset)

        assert_equal(dset.data.shape, (2, 3))
        assert_array_equal(dset.data, x)

        # transposing the data array
        dset = fetch_mldata(dataname, transpose_data=False, data_home=tmpdir)
        assert_equal(dset.data.shape, (3, 2))
    finally:
        datasets.mldata.urlopen = _urlopen_ref
mlp-mnist.py (project: NumpyDL, author: oujago)
def get_data():
    # data
    print("loading data, please wait ...")
    mnist = fetch_mldata('MNIST original', data_home=os.path.join(os.path.dirname(__file__), './data'))
    print('data loading is done ...')
    X_train = mnist.data / 255.0
    y_train = mnist.target
    n_classes = np.unique(y_train).size

    return n_classes, X_train, y_train
cnn-minist.py (project: NumpyDL, author: oujago)
def main(max_iter):
    seed = 100
    nb_data = 1000

    print("loading data ....")
    mnist = fetch_mldata('MNIST original', data_home=os.path.join(os.path.dirname(__file__), './data'))
    X_train = mnist.data.reshape((-1, 1, 28, 28)) / 255.0
    np.random.seed(seed)
    X_train = np.random.permutation(X_train)[:nb_data]
    y_train = mnist.target
    np.random.seed(seed)
    y_train = np.random.permutation(y_train)[:nb_data]
    n_classes = np.unique(y_train).size

    print("building model ...")
    net = npdl.Model()
    net.add(npdl.layers.Convolution(1, (3, 3), input_shape=(None, 1, 28, 28)))
    net.add(npdl.layers.MeanPooling((2, 2)))
    net.add(npdl.layers.Convolution(2, (4, 4)))
    net.add(npdl.layers.MeanPooling((2, 2)))
    net.add(npdl.layers.Flatten())
    net.add(npdl.layers.Softmax(n_out=n_classes))
    net.compile()

    print("train model ... ")
    net.fit(X_train, npdl.utils.data.one_hot(y_train), max_iter=max_iter, validation_split=0.1, batch_size=100)
BiRNN_mnist.py (project: Project101, author: Wonjuseo)
def data_read():
    return datasets.fetch_mldata('MNIST original', data_home='.')
main.py (project: deeplearning, author: turiphro)
def mnist():
    #digits = datasets.load_digits() # subsampled version
    mnist = datasets.fetch_mldata("MNIST original")
    print("Got the data.")
    X, y = mnist.data / 255., mnist.target
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]

    #images_and_labels = list(zip(digits.images, digits.target))
    #for index, (image, label) in enumerate(images_and_labels[:4]):
    #    plt.subplot(2, 4, index + 1)
    #    plt.axis('off')
    #    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    #    plt.title('Training: %i' % label)

    classifiers = [
        #("SVM", svm.SVC(gamma=0.001)), # TODO doesn't finish; needs downsampled version?
        ("NN", MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                             solver='sgd', verbose=10, tol=1e-4, random_state=1,
                             learning_rate_init=.1)),
    ]

    for name, classifier in classifiers:
        print(name)
        classifier.fit(X_train, y_train)
        predicted = classifier.predict(X_test)

        print("Classification report for classifier %s:\n%s\n"
              % (classifier, metrics.classification_report(y_test, predicted)))
        print("Confusion matrix:\n%s" % metrics.confusion_matrix(y_test, predicted))

        #images_and_predictions = list(zip(digits.images[n_samples / 2:], predicted))
        #for index, (image, prediction) in enumerate(images_and_predictions[:4]):
        #    plt.subplot(2, 4, index + 5)
        #    plt.axis('off')
        #    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        #    plt.title('Prediction: %i' % prediction)

        #plt.show()
data.py (project: a3c, author: siemanko)
def __init__(self, batch_size, validation_size):
        self.batch_size = batch_size

        # Load MNIST
        mnist = fetch_mldata('MNIST original')
        X, Y_labels = mnist['data'], mnist['target']

        # normalize X to (0.0, 1.0) range
        X = X.astype(np.float32) / 255.0

        # one hot encode the labels
        Y = np.zeros((len(Y_labels), 10))
        Y[range(len(Y_labels)), Y_labels.astype(np.int32)] = 1.

        # ensure type is float32
        X = X.astype(np.float32)
        Y = Y.astype(np.float32)

        # shuffle examples
        permutation = np.random.permutation(len(X))
        X = X[permutation]
        Y = Y[permutation]

        # split into train, validate, test
        train_end      = 60000 - validation_size
        validation_end = 60000
        test_end       = 70000

        self.X_train = X[0:train_end]
        self.X_valid = X[train_end:validation_end]
        self.X_test  = X[validation_end:test_end]

        self.Y_train = Y[0:train_end]
        self.Y_valid = Y[train_end:validation_end]
        self.Y_test  = Y[validation_end:test_end]
kin8nm.py (project: SCFGP, author: MaxInGaussian)
def load_kin8nm_data(proportion=3192./8192):
    from sklearn import datasets
    from sklearn import cross_validation
    kin8nm = datasets.fetch_mldata('regression-datasets kin8nm')
    X, y = kin8nm.data[:, :-1], kin8nm.data[:, -1]
    y = y[:, None]
    X = X.astype(np.float64)
    X_train, X_test, y_train, y_test = \
        cross_validation.train_test_split(X, y, test_size=proportion)
    return X_train, y_train, X_test, y_test
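Note that sklearn.cross_validation was deprecated in scikit-learn 0.18 and removed in 0.20; train_test_split now lives in sklearn.model_selection. A sketch of the modern split, with random stand-in arrays in place of the defunct mldata download:

# Modern equivalent of the cross_validation split used by the SCFGP
# loaders; the random arrays are stand-ins for the kin8nm data.
import numpy as np
from sklearn.model_selection import train_test_split

X, y = np.random.rand(8192, 8), np.random.rand(8192, 1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=3192./8192)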
test_regression.py (project: SCFGP, author: MaxInGaussian): defines the same load_kin8nm_data function as kin8nm.py above.
test_regression.py (project: SCFGP, author: MaxInGaussian)
def load_abalone_data(proportion=1044./4177):
    from sklearn import datasets
    from sklearn import preprocessing
    from sklearn import cross_validation
    abalone = datasets.fetch_mldata('regression-datasets abalone')
    X_cate = np.array([abalone.target[i].tolist()
        for i in range(abalone.target.shape[0])])
    X_cate = preprocessing.label_binarize(X_cate, np.unique(X_cate))
    X = np.hstack((X_cate, abalone.data))
    y = abalone.int1[0].T.astype(np.float64)
    y = y[:, None]
    X = X.astype(np.float64)
    X_train, X_test, y_train, y_test = \
        cross_validation.train_test_split(X, y, test_size=proportion)
    return X_train, y_train, X_test, y_test
abalone.py (project: SCFGP, author: MaxInGaussian): defines the same load_abalone_data function as test_regression.py above.
real.py (project: sdp_kmeans, author: simonsfoundation)
def mnist(digit='all', n_samples=0, return_gt=False):
    mnist = sk_datasets.fetch_mldata('MNIST original')
    X = mnist.data
    gt = mnist.target

    if digit == 'all':  # keep all digits
        pass
    else:
        X = X[gt == digit, :]
        gt = gt[gt == digit]

    if n_samples > len(X):
        raise ValueError('Requesting {} samples '
                         'from {} datapoints'.format(n_samples, len(X)))
    if n_samples > 0:
        np.random.seed(0)
        selection = np.random.randint(len(X), size=n_samples)
        X = X[selection, :]
        gt = gt[selection]

        idx = np.argsort(gt)
        X = X[idx, :]
        gt = gt[idx]

    if return_gt:
        return X, gt
    else:
        return X
uci_loader.py (project: sklearn-random-bits-forest, author: tmadl)
def getdataset(datasetname, onehot_encode_strings=True):
    # load
    dataset = fetch_mldata(datasetname)
    # get X and y
    X = dshape(dataset.data)
    try:
        target = dshape(dataset.target)
    except Exception:
        print("WARNING: No target found. Taking last column of data matrix as target")
        target = X[:, -1]
        X = X[:, :-1]
    if len(target.shape)>1 and target.shape[1]>X.shape[1]: # some mldata sets are mixed up...
        X = target
        target = dshape(dataset.data)
    if len(X.shape) == 1 or X.shape[1] <= 1:
        for k in dataset.keys():
            if k != 'data' and k != 'target' and len(dataset[k]) == X.shape[1]:
                X = np.hstack((X, dshape(dataset[k])))
    # one-hot for categorical values
    if onehot_encode_strings:
        cat_ft=[i for i in range(X.shape[1]) if 'str' in str(type(unpack(X[0,i]))) or 'unicode' in str(type(unpack(X[0,i])))]
        if len(cat_ft): 
            for i in cat_ft:
                X[:,i] = tonumeric(X[:,i]) 
            X = OneHotEncoder(categorical_features=cat_ft).fit_transform(X)
    # if sparse, make dense
    try:
        X = X.toarray()
    except AttributeError:
        pass
    # convert y to monotonically increasing ints
    y = tonumeric(target).astype(int)
    return np.nan_to_num(X.astype(float)),y
WGAN_mnist.py (project: Keras-GAN, author: Shaofanl)
def get_mnist(nbatch=128):
    mnist = fetch_mldata('MNIST original', data_home='/home/shaofan/.sklearn/') 
    x, y = mnist.data, mnist.target
    x = x.reshape(-1, 1, 28, 28)

    ind = np.random.permutation(x.shape[0])
    x = x[ind]
    y = y[ind]

    def random_stream():
        while 1:
            yield x[np.random.choice(x.shape[0], replace=False, size=nbatch)].transpose(0, 2, 3, 1)
    return x, y, random_stream
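random_stream is returned as a function; calling it yields an infinite generator of random NHWC batches. A short usage sketch:

# Usage sketch for the loader above: draw one random batch.
x, y, stream_fn = get_mnist(nbatch=128)
batches = stream_fn()      # infinite generator
batch = next(batches)      # shape (128, 28, 28, 1), NHWC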
infogan_mnist.py, aegan_mnist.py, init_with_ae.py, mnist.py (project: Keras-GAN, author: Shaofanl): each defines the same get_mnist function as WGAN_mnist.py above.
data_fetch.py (project: pylmnn, author: johny-c)
def fetch_from_config(cfg):
    data_set_name = cfg['fetch']['name']
    if cfg['fetch'].getboolean('sklearn'):
        if data_set_name == 'OLIVETTI':
            data_set = skd.fetch_olivetti_faces(shuffle=True)
        else:
            data_set = skd.fetch_mldata(data_set_name)
        X, y = data_set.data, data_set.target
        if data_set_name == 'MNIST original':
            if cfg['pre_process'].getboolean('normalize'):
                X = X / 255.
    else:
        if data_set_name == 'LETTERS':
            X, y = fetch_load_letters()
        elif data_set_name == 'ISOLET':
            x_tr, x_te, y_tr, y_te = fetch_load_isolet()
        elif data_set_name == 'SHREC14':
            X, y = load_shrec14(real=cfg['fetch']['real'], desc=cfg['fetch']['desc'])
            X = prep.normalize(X, norm=cfg['pre_process']['norm'])
        else:
            raise NameError('No data set {} found!'.format(data_set_name))

    # Separate training and testing set
    if data_set_name == 'MNIST original':
        x_tr, x_te, y_tr, y_te = X[:60000], X[60000:], y[:60000], y[60000:]
    elif data_set_name != 'ISOLET':
        test_size = cfg['train_test'].getfloat('test_size')
        x_tr, x_te, y_tr, y_te = train_test_split(X, y, test_size=test_size, stratify=y)

    return x_tr, x_te, y_tr, y_te
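fetch_from_config expects a configparser object; an illustrative configuration covering the keys the function reads (section and key names are taken from the code above, the values are assumptions):

# Illustrative config for fetch_from_config; values are assumptions.
import configparser

cfg = configparser.ConfigParser()
cfg.read_string("""
[fetch]
name = MNIST original
sklearn = yes

[pre_process]
normalize = yes

[train_test]
test_size = 0.3
""")
x_tr, x_te, y_tr, y_te = fetch_from_config(cfg)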
train_mnist.py (project: soinn, author: fukatani)
def prepare_dataset():
    print('load MNIST dataset')
    mnist = fetch_mldata('MNIST original')
    mnist['data'] = mnist['data'].astype(np.float32)
    mnist['data'] /= 255
    mnist['target'] = mnist['target'].astype(np.int32)
    return mnist
utils.py (project: chainer-adversarial-autoencoder, author: fukuta0614)
def load_mnist():
    mnist = fetch_mldata('MNIST original')
    mnist_X, mnist_y = shuffle(mnist.data.astype('float32'), mnist.target.astype('int32'), random_state=1234)

    mnist_X /= 255.
    mnist_y = np.eye(10)[mnist_y].astype('int32')
    x_train, x_test, y_train, y_test = train_test_split(mnist_X, mnist_y, test_size=0.2, random_state=1234)

    return x_train, x_test, y_train, y_test
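The np.eye(10)[labels] idiom above builds one-hot vectors by integer-indexing an identity matrix; a tiny standalone demonstration:

# One-hot encoding by indexing an identity matrix, as in load_mnist above.
import numpy as np
labels = np.array([3, 0, 9], dtype=np.int32)
one_hot = np.eye(10)[labels]  # shape (3, 10); row i is 1 at column labels[i]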
gan.py (project: mxnet-wgan, author: vsooda)
def get_mnist(image_size):
    mnist = fetch_mldata('MNIST original')
    np.random.seed(1234) # set seed for deterministic ordering
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p]
    X = X.reshape((70000, 1, image_size, image_size))
    Y = mnist.target[p]

    X = X.astype(np.float32)/(255.0/2) - 1.0
    X_train = X[:60000]
    X_test = X[60000:]
    Y_train = Y[:60000]
    Y_test = Y[60000:]

    return X_train, X_test, Y_train, Y_test
data.py (project: mxnet_tk1, author: starimpact)
def get_mnist():
    np.random.seed(1234) # set seed for deterministic ordering
    data_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    data_path = os.path.join(data_path, '../../data')
    mnist = fetch_mldata('MNIST original', data_home=data_path)
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p].astype(np.float32)*0.02
    Y = mnist.target[p]
    return X, Y
tune_params.py (project: SRU, author: akuzeee)
def load_mnist():
    mnist = fetch_mldata('MNIST original')
    mnist_X, mnist_y = shuffle(mnist.data, mnist.target, random_state=seed)
    mnist_X = mnist_X / 255.0

    # convert to the dtypes PyTorch expects
    mnist_X, mnist_y = mnist_X.astype('float32'), mnist_y.astype('int64')

    # flatten each 2-D image into a 1-D sequence, reversing every other row
    def flatten_img(images):
        '''
        images: shape => (n, rows, columns)
        output: shape => (n, rows*columns)
        '''
        n_rows    = images.shape[1]
        n_columns = images.shape[2]
        for num in range(n_rows):
            if num % 2 != 0:
                images[:, num, :] = images[:, num, :][:, ::-1]
        output = images.reshape(-1, n_rows*n_columns)
        return output

    mnist_X = mnist_X.reshape(-1, 28, 28)
    mnist_X = flatten_img(mnist_X) # X.shape => (n_samples, seq_len)
    mnist_X = mnist_X[:, :, np.newaxis] # X.shape => (n_samples, seq_len, n_features)

    # split into training and test sets
    train_X, test_X, train_y, test_y = train_test_split(mnist_X, mnist_y,
                                                        test_size=0.2,
                                                        random_state=seed)
    return train_X, test_X, train_y, test_y

