utils.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目：scalingscattering　作者：edouardoyallon（项目源码 · 文件源码）
def get_iterator(mode,opt):
    if (opt.imagenetpath is None):
        raise (RuntimeError('Where is imagenet?'))
    if (opt.N is None):
        raise (RuntimeError('Crop size not provided'))
    if (opt.batchSize is None):
        raise (RuntimeError('Batch Size not provided '))
    if (opt.nthread is None):
        raise (RuntimeError('num threads?'))


    def cvload(path):
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img


    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32) / 255.0,
        cvtransforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])




    print("| setting up data loader...")
    if mode:
        traindir = os.path.join(opt.imagenetpath, 'train')
        if (opt.max_samples > 0):
            ds = datasmall.ImageFolder(traindir, tnt.transform.compose([
                cvtransforms.RandomSizedCrop(opt.N),
                cvtransforms.RandomHorizontalFlip(),
                convert,
            ]), loader=cvload,maxSamp=opt.max_samples)
        else:
            ds =torchvision.datasets.ImageFolder(traindir, tnt.transform.compose([
            cvtransforms.RandomSizedCrop(opt.N),
            cvtransforms.RandomHorizontalFlip(),
            convert,
            ]), loader=cvload)
    else:
        if opt.N==224:
            crop_scale=256
        else:
            crop_scale=256*opt.N/224

        valdir = os.path.join(opt.imagenetpath, 'val')
        ds = torchvision.datasets.ImageFolder(valdir, tnt.transform.compose([
            cvtransforms.Scale(crop_scale),
            cvtransforms.CenterCrop(opt.N),
            convert,
        ]), loader=cvload)



    return torch.utils.data.DataLoader(ds,
                                       batch_size=opt.batchSize, shuffle=mode,
                                       num_workers=opt.nthread, pin_memory=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号