# Imports assumed for this excerpt, inferred from how the names are used below.
import os
import numpy as np
import cv2
import torch
import torch.utils.data
import torchnet as tnt
import torchvision.datasets as datasets
# T below refers to an OpenCV-based transform module (Normalize, Pad, RandomCrop, ...);
# its exact import is not shown in this excerpt and may differ in the original project.

def create_iterator(opt, mode):
    if opt.dataset.startswith('CIFAR'):
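        # Shared (test-time) pipeline: cast to float32, per-channel mean/std
        # normalization (values are in 0-255 pixel units), HWC -> CHW, then to a torch tensor.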
        convert = tnt.transform.compose([
            lambda x: x.astype(np.float32),
            T.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
            lambda x: x.transpose(2, 0, 1),
            torch.from_numpy,
        ])
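        # Training-time pipeline adds augmentation before the conversion above:
        # random horizontal flip, reflection padding by opt.randomcrop_pad pixels,
        # then a random 32x32 crop.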
        train_transform = tnt.transform.compose([
            T.RandomHorizontalFlip(),
            T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
            T.RandomCrop(32),
            convert,
        ])
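        # Pull the raw numpy images/labels from the torchvision dataset object and wrap
        # them in a torchnet TensorDataset; only field 0 (the images) is transformed,
        # and .parallel() wraps the dataset in a multi-worker DataLoader.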
        ds = getattr(datasets, opt.dataset)(opt.dataroot, train=mode, download=True)
        smode = 'train' if mode else 'test'
        ds = tnt.dataset.TensorDataset([getattr(ds, smode + '_data'),
                                        getattr(ds, smode + '_labels')])
        ds = ds.transform({0: train_transform if mode else convert})
        return ds.parallel(batch_size=opt.batchSize, shuffle=mode,
                           num_workers=opt.nthread, pin_memory=True)
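    # ImageNet: images are read from disk with OpenCV, which returns BGR order,
    # so convert to RGB before any further processing.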
    elif opt.dataset == 'ImageNet':
        def cvload(path):
            img = cv2.imread(path, cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return img
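        # Shared pipeline: scale pixels to [0, 1], normalize with the standard
        # ImageNet mean/std, reorder HWC -> CHW, then convert to a torch tensor.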
        convert = tnt.transform.compose([
            lambda x: x.astype(np.float32) / 255.0,
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            lambda x: x.transpose(2, 0, 1).astype(np.float32),
            torch.from_numpy,
        ])

        print("| setting up data loader...")
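        # Training reads from dataroot/train with random resized crops and flips;
        # evaluation reads from dataroot/val with a 256 resize followed by a 224 center crop.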
        if mode:
            traindir = os.path.join(opt.dataroot, 'train')
            ds = datasets.ImageFolder(traindir, tnt.transform.compose([
                T.RandomSizedCrop(224),
                T.RandomHorizontalFlip(),
                convert,
            ]), loader=cvload)
        else:
            valdir = os.path.join(opt.dataroot, 'val')
            ds = datasets.ImageFolder(valdir, tnt.transform.compose([
                T.Scale(256),
                T.CenterCrop(224),
                convert,
            ]), loader=cvload)
        return torch.utils.data.DataLoader(ds,
                                           batch_size=opt.batchSize, shuffle=mode,
                                           num_workers=opt.nthread, pin_memory=False)
    else:
        raise ValueError('dataset not understood')
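For reference, a minimal sketch of how create_iterator might be driven, assuming an opt object carrying the fields the function reads (dataset, dataroot, batchSize, nthread, randomcrop_pad); the field names follow the code above, the concrete values are only illustrative:

from argparse import Namespace

opt = Namespace(dataset='CIFAR10', dataroot='./data',
                batchSize=128, nthread=4, randomcrop_pad=4)

train_loader = create_iterator(opt, True)    # shuffled, augmented training batches
test_loader = create_iterator(opt, False)    # deterministic test batches

for images, labels in train_loader:
    # images: float32 tensor of shape (batchSize, 3, 32, 32); labels: class indices
    break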