main.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:SRU 作者: akuzeee 项目源码 文件源码
def load_mnist():
    """Load MNIST as pixel sequences, shuffled and split into train/test sets.

    Each 28x28 image is flattened in boustrophedon ("snake") order. Even rows
    are read left-to-right and odd rows right-to-left, so consecutive sequence
    elements are always neighbouring pixels. The result is shaped for sequence
    models as ``(n_samples, 784, 1)``.

    Both the shuffle and the split use the ``seed`` defined outside this
    function.

    Returns:
        tuple: ``(train_X, test_X, train_y, test_y)`` from an 80/20 split.
        ``*_X`` are float32 arrays of pixel values divided by 255, shaped
        ``(n, 784, 1)``; ``*_y`` are int64 class labels.
    """
    # ``fetch_mldata('MNIST original')`` no longer works: mldata.org went
    # offline and scikit-learn deprecated the function in 0.20 and removed it
    # in 0.22. OpenML's ``mnist_784`` provides the same MNIST images.
    # NOTE(review): a module-level ``fetch_mldata`` import will also fail on
    # scikit-learn >= 0.22 and should be removed.
    from sklearn.datasets import fetch_openml

    mnist = fetch_openml('mnist_784', version=1)
    # Depending on the scikit-learn version, data/target may be ndarrays or
    # pandas objects, and OpenML labels are strings such as '5'.
    # np.asarray plus the astype calls below normalise all of these.
    mnist_X = np.asarray(mnist.data)
    mnist_y = np.asarray(mnist.target)
    mnist_X, mnist_y = shuffle(mnist_X, mnist_y, random_state=seed)
    mnist_X = mnist_X / 255.0

    # Convert to the dtypes PyTorch expects: float32 inputs, int64 targets.
    mnist_X, mnist_y = mnist_X.astype('float32'), mnist_y.astype('int64')

    def flatten_img(images):
        '''Flatten 2-D images into 1-D sequences in snake (boustrophedon) order.

        images: shape => (n, rows, columns); odd rows are reversed IN PLACE.
        output: shape => (n, rows*columns)
        '''
        n_rows, n_columns = images.shape[1], images.shape[2]
        # Reverse every odd row in one vectorised assignment. The explicit
        # copy guarantees the source never aliases the destination.
        images[:, 1::2, :] = images[:, 1::2, ::-1].copy()
        return images.reshape(-1, n_rows * n_columns)

    mnist_X = mnist_X.reshape(-1, 28, 28)
    mnist_X = flatten_img(mnist_X)  # X.shape => (n_samples, seq_len)
    mnist_X = mnist_X[:, :, np.newaxis]  # X.shape => (n_samples, seq_len, n_features)

    # Hold out 20% of the samples as the test set.
    train_X, test_X, train_y, test_y = train_test_split(mnist_X, mnist_y,
                                                        test_size=0.2,
                                                        random_state=seed)
    return train_X, test_X, train_y, test_y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号