mnist.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:ddnn 作者: kunglab 项目源码 文件源码
def get_mnist():
    train, test = chainer.datasets.get_mnist(ndim=3)

    train_data = [t for t in train]
    test_data = [t for t in test]

    train_data = np.array(train_data)
    test_data = np.array(test_data)

    train_data = np.expand_dims(train_data, 1)
    test_data = np.expand_dims(test_data, 1)

    train_xs = train_data[:,:,0].T
    train_ys = train_data[:,:,1].T

    test_xs = test_data[:,:,0].T
    test_ys = test_data[:,:,1].T

    train = TupleDataset(*(train_xs.tolist() + train_ys.tolist()))
    test = TupleDataset(*(test_xs.tolist() + test_ys.tolist()))

    return train,test
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号