def _read_fold_indices(path, fold):
    """Return the integer sample indices listed on line ``fold`` of ``path``.

    Tokens on a line are whitespace-separated. ``split()`` (no argument)
    ignores a trailing separator, so lines both with and without a trailing
    space parse correctly.

    Raises:
        IndexError: if ``fold`` is not a valid line index of the file.
    """
    with open(path) as f:  # context manager: the handle is always closed
        lines = f.read().splitlines()
    # Build a real list: a lazy ``map`` object cannot be used as a NumPy index.
    return [int(tok) for tok in lines[fold].split()]


def create_dataset(opt, mode, fold=0):
    """Build a torchnet dataset of (image, label) pairs with transforms attached.

    Args:
        opt: options object. Reads ``opt.dataset`` (name of a dataset class
            looked up on ``datasets``) and ``opt.randomcrop_pad`` (padding
            applied before the random 96x96 crop).
        mode: truthy -> ``'train'`` split with augmentation (flip, reflect
            pad, random crop); falsy -> ``'test'`` split with conversion only.
        fold: train split only. Index of the line in
            ``./stl10_binary/fold_indices.txt`` whose sample indices are kept.
            Any value <= -1 keeps the whole train split. Ignored when ``mode``
            is falsy.

    Returns:
        The result of ``TensorDataset.transform`` with the train or test
        transform applied to element 0 (the image) of each sample.
    """
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        lambda x: x / 255.0,  # presumably 0-255 pixel values -> [0, 1]
        # Normalization intentionally disabled:
        # cvtransforms.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        # Move the last axis to the front (presumably HWC -> CHW for torch).
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])
    train_transform = tnt.transform.compose([
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        cvtransforms.RandomCrop(96),
        convert,
    ])

    smode = 'train' if mode else 'test'
    ds = getattr(datasets, opt.dataset)('.', split=smode, download=True)

    # Move axis 1 to the end, presumably (N, C, H, W) -> (N, H, W, C) so the
    # OpenCV-style cvtransforms receive channel-last images -- confirm.
    data = ds.data.transpose(0, 2, 3, 1)
    labels = ds.labels
    if mode and fold > -1:
        idx = _read_fold_indices('./stl10_binary/fold_indices.txt', fold)
        data, labels = data[idx], labels[idx]

    ds = tnt.dataset.TensorDataset([data, labels.tolist()])
    return ds.transform({0: train_transform if mode else convert})
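# 评论列表 ("Comment list") -- web-page navigation residue, not code
# 文章目录 ("Table of contents") -- web-page navigation residue, not code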