cifar_prepare.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def loadData(src):
    print('Downloading ' + src)
    fname, h = urlretrieve(src, './delete.me')
    print('Done.')
    try:
        print('Extracting files...')
        with tarfile.open(fname) as tar:
            tar.extractall()
        print('Done.')
        print('Preparing train set...')
        trn = np.empty((0, numFeature + 1), dtype=np.int)
        for i in range(5):
            batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1)
            trn = np.vstack((trn, readBatch(batchName)))
        print('Done.')
        print('Preparing test set...')
        tst = readBatch('./cifar-10-batches-py/test_batch')
        print('Done.')
    finally:
        os.remove(fname)
    return (trn, tst)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号