def get_cifar100(save_dir=None, root_path=None):
    """Load CIFAR-100, downloading and extracting it first when needed.

    Exactly one of ``save_dir`` and ``root_path`` must be given:

    * ``save_dir``: directory into which ``cifar-100-python.tar.gz`` is
      downloaded from https://www.cs.toronto.edu/~kriz/ and extracted (the
      directory is created if missing); data is then read from
      ``<save_dir>/cifar-100-python``.
    * ``root_path``: directory that already contains the extracted ``train``
      and ``test`` files; nothing is downloaded.

    Shapes of the four loaded arrays are printed to stdout.

    Returns:
        Tuple ``(Xtr, Ytr, Xte, Yte)``: training inputs/labels and test
        inputs/labels, each as returned by ``load_cifar100_data`` (numpy
        arrays per the original documentation).

    Raises:
        ValueError: if both or neither of ``save_dir`` and ``root_path`` are
            given.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if (save_dir is None) == (root_path is None):
        raise ValueError("exactly one of save_dir and root_path must be given")

    if root_path is None:
        # Local import so the function works on both Python 2 and Python 3;
        # ``urllib.URLopener`` is deprecated (and absent from ``urllib`` on Py3).
        try:
            from urllib.request import urlretrieve  # Python 3
        except ImportError:
            from urllib import urlretrieve  # Python 2

        print('Downloading CIFAR100 dataset...')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        tar_path = os.path.join(save_dir, "cifar-100-python.tar.gz")
        urlretrieve("https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz", tar_path)
        print('Download Done, Extracting...')
        # ``with`` guarantees the archive is closed even if extraction fails.
        # NOTE: extractall does not sanitize member paths; this is acceptable
        # only because the archive comes from the fixed URL above.
        with tarfile.open(tar_path) as tar:
            tar.extractall(save_dir)
        root = os.path.join(save_dir, "cifar-100-python")
    else:
        root = root_path

    Xtr, Ytr = load_cifar100_data(os.path.join(root, 'train'))
    Xte, Yte = load_cifar100_data(os.path.join(root, 'test'))

    # Same output as the original ``print 'Xtrain shape', Xtr.shape`` lines,
    # written so it behaves identically on Python 2 and 3.
    for label, arr in (('Xtrain', Xtr), ('Ytrain', Ytr), ('Xtest', Xte), ('Ytest', Yte)):
        print('%s shape %s' % (label, arr.shape))
    return Xtr, Ytr, Xte, Yte
评论列表
文章目录