def create_dataset(opt, mode):
    """Build the torchnet dataset for the training or the test split.

    Args:
        opt: Options object. This function reads ``opt.dataset`` (name of a
            class looked up on ``datasets``), ``opt.data_root`` (first
            positional argument of that class) and ``opt.randomcrop_pad``
            (padding handed to ``T.Pad``).
        mode: Truthy selects the training split together with the augmenting
            transform; falsy selects the test split with conversion only.

    Returns:
        The object returned by ``TensorDataset.transform``: samples are
        ``[data, labels]`` pairs and the selected transform is registered
        for element 0 (the data entry).

    Raises:
        AttributeError: If the dataset object exposes neither the legacy
            ``<split>_data`` / ``<split>_labels`` attributes nor
            ``data`` / ``targets``.
    """
    # Conversion shared by both splits: cast, normalise, reorder axes, wrap
    # as a torch tensor.
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        # NOTE(review): constants look like per-channel mean/std on a 0-255
        # scale -- confirm against the dataset actually selected.
        T.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        # Moves the last axis to the front (presumably HWC -> CHW); the cast
        # is repeated in case normalisation changed the dtype.
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    # Training-only augmentation: random flip, reflect-pad, random 32-pixel
    # crop, then the shared conversion.
    train_transform = tnt.transform.compose([
        T.RandomHorizontalFlip(),
        T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        T.RandomCrop(32),
        convert,
    ])

    # download=True is forwarded to the dataset class as-is.
    ds = getattr(datasets, opt.dataset)(opt.data_root, train=mode, download=True)
    smode = 'train' if mode else 'test'

    # Older torchvision datasets expose split-prefixed attributes
    # (train_data / test_labels, ...); newer releases expose `data` and
    # `targets` instead. Prefer the legacy names so existing setups behave
    # exactly as before, and fall back to the newer ones otherwise.
    data = getattr(ds, smode + '_data', None)
    if data is None:
        data = ds.data
    labels = getattr(ds, smode + '_labels', None)
    if labels is None:
        labels = ds.targets

    ds = tnt.dataset.TensorDataset([data, labels])
    return ds.transform({0: train_transform if mode else convert})
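# 评论列表 (comment list) -- web-page navigation residue, not part of the code
# 文章目录 (table of contents) -- web-page navigation residue, not part of the code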