def install(
        self, local_dst_dir_=None, local_src_dir_=None, clean_install_=False):
    '''
    Install the dataset into a directly usable format.

    Reads the raw CIFAR-10 python archive "cifar-10-python.tar.gz" from the
    source directory and writes "cifar10.npz" into the destination
    directory. The npz holds ``images`` (60000 x 3 x 32 x 32, uint8) and
    ``labels`` (60000, uint8): entries 0..49999 come from data_batch_1..5,
    entries 50000..59999 from test_batch. For the public dataset the
    archive must be downloaded first.

    Args:
        local_dst_dir_: string or None
            where to install the dataset, None -> "%(default_dir)s"
        local_src_dir_: string or None
            where to find the raw downloaded files, None -> "%(default_dir)s"
        clean_install_: bool
            if True, delete the source archive after the npz is written

    Raises:
        FileNotFoundError: if the source archive does not exist.
    '''
    # Wrap in Path so a str DEFAULT_DIR works as well as a Path.
    local_dst_dir = Path(
        self.DEFAULT_DIR if local_dst_dir_ is None else local_dst_dir_)
    local_src_dir = Path(
        self.DEFAULT_DIR if local_src_dir_ is None else local_src_dir_)
    tarfile_path = local_src_dir / 'cifar-10-python.tar.gz'

    # Validate explicitly (an ``assert`` vanishes under ``python -O``) and
    # before creating the destination, so a bad source leaves no side effects.
    if not tarfile_path.is_file():
        raise FileNotFoundError(
            'CIFAR-10 archive not found: %s' % tarfile_path)
    local_dst_dir.mkdir(parents=True, exist_ok=True)

    batch_size = 10000
    image_shape = (3, 32, 32)
    # Five training batches followed by the test batch, in output order.
    members = ['cifar-10-batches-py/data_batch_%d' % (i + 1) for i in range(5)]
    members.append('cifar-10-batches-py/test_batch')

    images = np.empty((batch_size * len(members),) + image_shape, dtype=np.uint8)
    labels = np.empty((batch_size * len(members),), dtype=np.uint8)

    # NOTE(security): pickle.load can execute arbitrary code; only install
    # archives obtained from a trusted source.
    with tarfile.open(str(tarfile_path), 'r:gz') as tf:
        for i, member in enumerate(members):
            with tf.extractfile(member) as f:
                data_di = pickle.load(f, encoding='bytes')
            rows = slice(batch_size * i, batch_size * (i + 1))
            images[rows] = data_di[b'data'].reshape((batch_size,) + image_shape)
            labels[rows] = np.asarray(data_di[b'labels'], dtype=np.uint8)

    np.savez_compressed(
        str(local_dst_dir / 'cifar10.npz'), images=images, labels=labels)
    if clean_install_:
        os.remove(str(tarfile_path))
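# (stray page text, translated from Chinese) Comment list
# (stray page text, translated from Chinese) Table of contents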