get_datasets.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:metaqnn 作者: bowenbaker 项目源码 文件源码
def _download_caltech101(save_dir):
    """Download the Caltech101 archive into ``save_dir`` and extract it there."""
    # NOTE(security): hostname checking and certificate verification are
    # disabled here, so the downloaded archive is NOT authenticated and is
    # then unpacked with ``extractall``. This is kept from the original
    # behaviour; only use it with a ``save_dir`` you are happy to have
    # overwritten, and prefer ``root_path`` with a pre-verified copy.
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    # Single-argument parenthesised prints emit the same text whether
    # ``print`` is the Python 2 statement or the print function.
    print('Downloading Caltech101 dataset...')
    tar_path = os.path.join(save_dir, "101_ObjectCategories.tar.gz")
    url = urllib.URLopener(context=ctx)
    url.retrieve("https://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", tar_path)
    print('Download Done, Extracting...')
    tar = tarfile.open(tar_path)
    try:
        tar.extractall(save_dir)
    finally:
        # Close the archive even if extraction fails part-way.
        tar.close()


def _load_class_images(cls_root):
    """Load every image in ``cls_root`` as a shuffled (N, 3, 224, 224) array."""
    images = []
    # Sorted so that, for a seeded RNG, the shuffle below is reproducible
    # regardless of the order the filesystem lists the files in.
    for img_name in sorted(os.listdir(cls_root)):
        img = misc.imread(os.path.join(cls_root, img_name))
        if len(img.shape) == 2:
            # Grayscale: replicate the single channel three times.
            img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
        resized = misc.imresize(img, (224, 224, 3))
        # Move the channel axis to the front (HWC -> CHW). A plain reshape
        # would only reinterpret the buffer and scramble the pixels.
        images.append(np.transpose(resized, (2, 0, 1)))

    images = np.array(images)
    return images[np.random.permutation(len(images))]


def get_caltech101(save_dir=None, root_path=None):
    """Build train/validation arrays from the Caltech101 image folders.

    Exactly one of the two arguments must be given.

    Args:
        save_dir: Directory to download the archive into and extract it in;
            the images are then read from ``save_dir/101_ObjectCategories``.
        root_path: Directory that already contains one sub-folder per class.
            Nothing is downloaded in this case.

    Returns:
        A tuple ``(Xtr, Ytr, Xval, Yval)``. Every sub-directory of the root
        is treated as one class and gets an integer label, assigned in
        sorted folder-name order. After a random shuffle of each class, the
        first 30 images go to the training set and up to the next 50 go to
        the validation set. Image arrays have shape ``(N, 3, 224, 224)``;
        label arrays hold one integer per image.

    Raises:
        ValueError: If both or neither of ``save_dir`` and ``root_path`` are
            given, or if no class folder containing images is found.
    """
    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if (save_dir is None) == (root_path is None):
        raise ValueError("Provide exactly one of save_dir or root_path")

    if root_path is None:
        _download_caltech101(save_dir)
        root = os.path.join(save_dir, "101_ObjectCategories")
    else:
        root = root_path

    train_x = []
    train_y = []
    val_x = []
    val_y = []

    label = 0
    # Sorted so that class ids do not depend on filesystem listing order.
    for cls_folder in sorted(os.listdir(root)):
        cls_root = os.path.join(root, cls_folder)
        if not os.path.isdir(cls_root):
            continue

        cls_images = _load_class_images(cls_root)
        if len(cls_images) == 0:
            # Nothing to learn from an empty folder; do not spend a label.
            continue

        train_imgs = cls_images[:30]
        val_imgs = cls_images[30:80]

        # Label counts follow the actual slice sizes so X and Y stay aligned
        # for small classes; the explicit dtype keeps empty label arrays
        # from turning the concatenated labels into floats.
        train_x.append(train_imgs)
        train_y.append(np.array([label] * len(train_imgs), dtype=int))
        val_x.append(val_imgs)
        val_y.append(np.array([label] * len(val_imgs), dtype=int))
        label += 1

    if not train_x:
        raise ValueError("No class folders with images found under %r" % (root,))

    Xtr = np.concatenate(train_x)
    Ytr = np.concatenate(train_y)
    Xval = np.concatenate(val_x)
    Yval = np.concatenate(val_y)

    # Two spaces before %s reproduce the output of the original
    # ``print 'label ', value`` statements byte-for-byte.
    print('Xtr shape  %s' % (Xtr.shape,))
    print('Ytr shape  %s' % (Ytr.shape,))
    print('Xval shape  %s' % (Xval.shape,))
    print('Yval shape  %s' % (Yval.shape,))

    return Xtr, Ytr, Xval, Yval
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号