data_loader.py source code

python

Project: mxnet_tk1  Author: starimpact
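import os
import ssl
import urllib.request

import numpy


def load_mnist(training_num=50000):
    # NOTE: '__file__' is a quoted string here, so the path resolves
    # relative to the current working directory, not this module's location.
    data_path = os.path.join(os.path.dirname(os.path.realpath('__file__')), 'mnist.npz')
    if not os.path.isfile(data_path):
        origin = (
            'https://github.com/sxjscience/mxnet/raw/master/example/bayesian-methods/mnist.npz'
        )
        print('Downloading data from %s to %s' % (origin, data_path))
        # Certificate verification is disabled, as in the original code.
        context = ssl._create_unverified_context()
        # urlretrieve() takes no `context` argument in Python 3, so use urlopen().
        with urllib.request.urlopen(origin, context=context) as response, \
                open(data_path, 'wb') as out_file:
            out_file.write(response.read())
        print('Done!')
    dat = numpy.load(data_path)
    # Keep the first `training_num` samples; the 126.0 scale factor is the original's.
    X = (dat['X'][:training_num] / 126.0).astype('float32')
    Y = dat['Y'][:training_num]
    X_test = (dat['X_test'] / 126.0).astype('float32')
    Y_test = dat['Y_test']
    # Flatten the label arrays to 1-D vectors.
    Y = Y.reshape((Y.shape[0],))
    Y_test = Y_test.reshape((Y_test.shape[0],))
    return X, Y, X_test, Y_test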
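
The function above downloads mnist.npz only when it is not already on disk. The following is a minimal, standard-library-only sketch of that download-once pattern on its own. It is not part of the mxnet_tk1 project: `fetch_if_missing`, the `fetch` callable, and the `.part` temporary suffix are hypothetical names introduced for illustration, and writing to a temporary name first is an added assumption that load_mnist itself does not make.

```python
import os
import tempfile


def fetch_if_missing(path, fetch):
    """Create `path` from `fetch()` unless it already exists.

    `fetch` is a hypothetical callable that returns the file contents
    as bytes. Returns True if `fetch` was called, False otherwise.
    """
    if os.path.isfile(path):
        return False
    # Write to a temporary name first so an interrupted download
    # does not leave a truncated file that later looks valid.
    tmp_path = path + '.part'
    with open(tmp_path, 'wb') as out_file:
        out_file.write(fetch())
    os.replace(tmp_path, path)
    return True


def demo():
    """Call fetch_if_missing twice and report what happened."""
    calls = []

    def fake_fetch():
        # Stand-in for a real network download.
        calls.append(1)
        return b'payload'

    with tempfile.TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, 'mnist.npz')
        first = fetch_if_missing(target, fake_fetch)
        second = fetch_if_missing(target, fake_fetch)
        with open(target, 'rb') as in_file:
            content = in_file.read()
    return first, second, len(calls), content
```

In `demo()`, the second call finds the file already present and skips the fetch, which is the same check load_mnist performs with `os.path.isfile(data_path)`.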