def get_iterator(mode, opt):
    """Build a DataLoader over the ImageNet train or validation split.

    Args:
        mode: truthy selects the training split (``<imagenetpath>/train``,
            random-sized crop + random horizontal flip, shuffled batches);
            falsy selects the validation split (``<imagenetpath>/val``,
            scale + center crop, unshuffled).
        opt: options object. Reads ``imagenetpath``, ``N`` (crop size),
            ``batchSize`` and ``nthread`` (DataLoader worker count). In
            training mode it also reads ``max_samples``: when it is set and
            > 0, ``datasmall.ImageFolder`` is used with
            ``maxSamp=max_samples``; otherwise the full
            ``torchvision.datasets.ImageFolder`` is used.

    Returns:
        torch.utils.data.DataLoader over the selected split.

    Raises:
        RuntimeError: if any of ``imagenetpath``, ``N``, ``batchSize`` or
            ``nthread`` is None.
    """
    required = (
        ('imagenetpath', 'Where is imagenet?'),
        ('N', 'Crop size not provided'),
        ('batchSize', 'Batch Size not provided '),
        ('nthread', 'num threads?'),
    )
    for attr, message in required:
        if getattr(opt, attr) is None:
            raise RuntimeError(message)

    def cvload(path):
        # Load an image with OpenCV and convert it from BGR to RGB order.
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread returns None instead of raising on unreadable files;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise IOError('Could not read image: %s' % path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # HWC image -> float in [0, 1] (assuming 8-bit input, hence /255) ->
    # per-channel mean/std normalization -> CHW float32 torch tensor.
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32) / 255.0,
        cvtransforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    print("| setting up data loader...")
    if mode:
        traindir = os.path.join(opt.imagenetpath, 'train')
        train_transform = tnt.transform.compose([
            cvtransforms.RandomSizedCrop(opt.N),
            cvtransforms.RandomHorizontalFlip(),
            convert,
        ])
        # A missing/None max_samples means "no limit" (comparing None > 0
        # would raise TypeError on Python 3).
        max_samples = getattr(opt, 'max_samples', None)
        if max_samples is not None and max_samples > 0:
            ds = datasmall.ImageFolder(traindir, train_transform,
                                       loader=cvload, maxSamp=max_samples)
        else:
            ds = torchvision.datasets.ImageFolder(traindir, train_transform,
                                                  loader=cvload)
    else:
        # Scale to 256/224 of the crop size before center-cropping (256 for
        # the standard 224 crop). Integer division keeps the size an int, as
        # under Python 2; true division would give a float under Python 3,
        # which cvtransforms.Scale presumably cannot use as an image size.
        crop_scale = int(256 * opt.N // 224)
        valdir = os.path.join(opt.imagenetpath, 'val')
        ds = torchvision.datasets.ImageFolder(valdir, tnt.transform.compose([
            cvtransforms.Scale(crop_scale),
            cvtransforms.CenterCrop(opt.N),
            convert,
        ]), loader=cvload)
    return torch.utils.data.DataLoader(ds,
                                       batch_size=opt.batchSize,
                                       shuffle=bool(mode),
                                       num_workers=opt.nthread,
                                       pin_memory=False)
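# 评论列表 ("comment list") — apparently navigation text from the web page this code was copied from
# 文章目录 ("table of contents") — apparently navigation text from the web page this code was copied from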