mnist.py source code

python

Project: latplan    Author: guicho271828
import numpy as np

def mnist (labels = range(10)):
    # Load MNIST and binarize pixels: scale to [0, 1], then round to 0 or 1.
    from keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (x_train.astype('float32') / 255.).round()
    x_test = (x_test.astype('float32') / 255.).round()
    # Flatten each 28x28 image into a 784-element vector.
    x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
    x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
    def conc (x,y):
        # Prepend the label as column 0 of each row.
        return np.concatenate((y.reshape([len(y),1]),x),axis=1)
    def select (x,y):
        # Keep rows whose label is in `labels`, then split the label column off again.
        # Labels come back as float32 because they were concatenated with the pixels.
        selected = np.array([elem for elem in conc(x, y) if elem[0] in labels])
        return np.delete(selected,0,1), np.delete(selected,np.s_[1::],1).flatten()
    x_train, y_train = select(x_train, y_train)
    x_test, y_test = select(x_test, y_test)
    return x_train, y_train, x_test, y_test
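The label filter in `select` builds a Python list row by row. A minimal sketch of an equivalent vectorized filter, assuming the images are already flattened as in `mnist` above: `np.isin` produces a boolean mask over the labels, which indexes both arrays directly. The helper name `select_labels` and the toy arrays are hypothetical, used only for illustration; real data would come from `keras.datasets.mnist.load_data()`.

```python
import numpy as np

def select_labels(x, y, labels):
    """Keep only the rows of x (and entries of y) whose label is in `labels`.

    Hypothetical helper: a vectorized alternative to the row-by-row filter.
    np.isin builds a boolean mask over y that indexes both arrays, so row
    order is preserved and no label column has to be concatenated and removed.
    """
    mask = np.isin(y, list(labels))
    return x[mask], y[mask]

# Hypothetical toy stand-in for flattened, binarized MNIST:
# 6 "images" of 4 pixels each, with pixel values 0.0 or 1.0.
x = (np.arange(24, dtype='float32').reshape(6, 4) % 3 == 0).astype('float32')
y = np.array([0, 1, 2, 1, 0, 3], dtype='uint8')

x_sel, y_sel = select_labels(x, y, labels=[0, 1])
print(x_sel.shape, y_sel)  # (4, 4) [0 1 1 0]
```

One difference from the original: here the labels keep their own dtype (`uint8` for MNIST), whereas `select` returns them as `float32` because they pass through the concatenated pixel array.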