def get_caltech101(save_dir=None, root_path=None):
    """Load Caltech101 as fixed-size train/validation image arrays.

    Exactly one of ``save_dir`` and ``root_path`` must be given:

    * ``save_dir``: the dataset archive is downloaded into this directory
      and extracted there; images are then read from
      ``<save_dir>/101_ObjectCategories``.
    * ``root_path``: an already extracted dataset root holding one
      sub-directory per class; nothing is downloaded.

    Each image is read, expanded to 3 channels if it is 2-D (grayscale),
    resized to 224x224 and stored channels-first.  Within every class the
    images are shuffled with ``np.random``; up to the first 30 go to the
    training split and up to the next 50 to the validation split.  Labels
    are consecutive integers assigned in sorted class-folder order, so the
    mapping is reproducible across runs and file systems.

    Returns:
        Tuple ``(Xtr, Ytr, Xval, Yval)``.  The image arrays have shape
        ``(N, 3, 224, 224)``; the label arrays are 1-D integer arrays with
        one entry per image.

    Raises:
        AssertionError: if both or neither of ``save_dir`` / ``root_path``
            are given.
        ValueError: if no class folder containing files is found.
    """
    assert (save_dir is None) != (root_path is None), \
        'exactly one of save_dir and root_path must be given'

    if root_path is None:
        # NOTE(review): certificate and hostname verification are switched
        # off here, so the download is not authenticated.  Presumably a
        # workaround for the host's certificate -- confirm whether it is
        # still required before relying on the downloaded data.
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        print('Downloading Caltech101 dataset...')
        tar_path = os.path.join(save_dir, "101_ObjectCategories.tar.gz")
        url = urllib.URLopener(context=ctx)
        url.retrieve("https://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", tar_path)
        print('Download Done, Extracting...')
        # The context manager closes the archive even if extraction fails.
        # NOTE(review): extractall() trusts member paths inside the archive.
        with tarfile.open(tar_path) as tar:
            tar.extractall(save_dir)
        root = os.path.join(save_dir, "101_ObjectCategories")
    else:
        root = root_path

    n_train = 30   # images per class in the training split
    n_val = 50     # maximum images per class in the validation split

    train_x = []
    train_y = []
    val_x = []
    val_y = []
    label = 0
    # os.listdir() order is arbitrary; sort so label ids are deterministic.
    for cls_folder in sorted(os.listdir(root)):
        cls_root = os.path.join(root, cls_folder)
        if not os.path.isdir(cls_root):
            continue
        img_names = os.listdir(cls_root)
        if not img_names:
            # An empty class folder contributes nothing and gets no label.
            continue

        cls_images = []
        for img_name in img_names:
            img = misc.imread(os.path.join(cls_root, img_name))
            if len(img.shape) == 2:
                # Grayscale: replicate the single plane into 3 channels.
                img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
            img = misc.imresize(img, (224, 224, 3))
            # Move the channel axis to the front.  A plain reshape to
            # (3, 224, 224) would reinterpret the H x W x C buffer and
            # scramble the pixels instead of reordering the axes.
            cls_images.append(np.transpose(img, (2, 0, 1)))
        cls_images = np.array(cls_images)

        new_index = np.random.permutation(np.arange(cls_images.shape[0]))
        cls_images = cls_images[new_index]

        train_part = cls_images[:n_train]
        val_part = cls_images[n_train:n_train + n_val]
        # Size the label vectors from the actual slices so images and
        # labels stay aligned for classes with fewer than 30 (or 80)
        # images; dtype=int keeps empty label arrays integer-typed.
        train_x.append(train_part)
        train_y.append(np.array([label] * len(train_part), dtype=int))
        val_x.append(val_part)
        val_y.append(np.array([label] * len(val_part), dtype=int))
        label += 1

    if not train_x:
        raise ValueError('no class folders with images found under %r' % (root,))

    Xtr = np.concatenate(train_x)
    Ytr = np.concatenate(train_y)
    Xval = np.concatenate(val_x)
    Yval = np.concatenate(val_y)
    # Single-argument print() emits the same text on Python 2 and 3.
    print('Xtr shape  %s' % (Xtr.shape,))
    print('Ytr shape  %s' % (Ytr.shape,))
    print('Xval shape  %s' % (Xval.shape,))
    print('Yval shape  %s' % (Yval.shape,))
    return Xtr, Ytr, Xval, Yval
评论列表
文章目录