def _balanced_subset_indices(labels, sample_size, seed):
    """Return indices of a class-balanced random subset of ``labels``.

    One permutation is drawn from ``numpy.random.RandomState(seed)``. Its
    first ``sample_size // n_classes`` entries are used as positions into
    each class's ascending list of sample indices. Every class therefore
    contributes the same number of samples, chosen at the same positions.

    Args:
        labels: 1-D array of per-sample class labels.
        sample_size: total number of samples to keep. It must be a multiple
            of the number of distinct labels.
        seed: seed passed to ``numpy.random.RandomState``.

    Returns:
        1-D integer array of indices into ``labels``. The indices are grouped
        by class, with the highest class first.

    Raises:
        ValueError: if ``labels`` is empty, if ``sample_size`` is not
            divisible by the number of classes, or if it asks for more
            samples per class than the smallest class holds.
    """
    classes, counts = np.unique(labels, return_counts=True)
    n_classes = len(classes)
    if n_classes == 0:
        raise ValueError("cannot subsample an empty label set")
    if sample_size % n_classes != 0:
        raise ValueError(
            "sampleSize (%d) must be a multiple of the number of classes (%d)"
            % (sample_size, n_classes))
    per_class = sample_size // n_classes  # integer division: used as a slice bound
    pool = int(counts.min())
    if per_class > pool:
        raise ValueError(
            "sampleSize (%d) needs %d samples per class but the smallest "
            "class only has %d" % (sample_size, per_class, pool))
    prng = np.random.RandomState(seed)
    # A single permutation is shared by all classes, so each class keeps the
    # same within-class positions.
    positions = prng.permutation(np.arange(0, pool))[:per_class]
    per_class_inds = [np.flatnonzero(labels == cl)[positions] for cl in classes]
    # Highest class first: this reproduces the original code, which prepended
    # each successive class's indices.
    return np.concatenate(per_class_inds[::-1])


def create_dataset(opt, mode):
    """Build a torchnet ``TensorDataset`` for ``opt.dataset``.

    The dataset class is looked up by name on ``datasets``. It is constructed
    with root ``'.'``, ``train=mode`` and ``download=True``.

    In training mode (``mode`` truthy), only a class-balanced random subset
    of ``opt.sampleSize`` training images is kept. The subset is determined
    by ``opt.seed``. Images then get random horizontal flip, reflect-padding
    by ``opt.randomcrop_pad``, and a random 32x32 crop. In test mode the full
    test split is used with no augmentation. In both modes images finish with
    the ``convert`` pipeline below.

    Args:
        opt: options object. Reads ``dataset`` and ``randomcrop_pad``, plus
            ``sampleSize`` and ``seed`` in training mode.
        mode: truthy for the training split, falsy for the test split.

    Returns:
        A ``tnt.dataset.TensorDataset`` of ``[images, labels]``. Field 0
        (the images) is transformed per sample.

    Raises:
        ValueError: in training mode, if ``opt.sampleSize`` cannot be split
            evenly across the classes or exceeds what the smallest class holds.
    """
    # Per-image conversion: cast to float32, divide by 255, move the last
    # axis to the front, and wrap as a torch tensor.
    # NOTE(review): assumes uint8 HWC images in [0, 255]; confirm.
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        lambda x: x / 255.0,
        # cvtransforms.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])
    train_transform = tnt.transform.compose([
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        cvtransforms.RandomCrop(32),
        convert,
    ])
    ds = getattr(datasets, opt.dataset)('.', train=mode, download=True)

    # NOTE(review): the transpose(0, 2, 3, 1) calls assume the dataset stores
    # images channels-first as (N, C, H, W), as older torchvision CIFAR
    # versions did. They convert to (N, H, W, C) for the OpenCV-style
    # transforms. Confirm against the installed torchvision.
    if mode:
        labels = np.array(ds.train_labels)
        inds = _balanced_subset_indices(labels, opt.sampleSize, opt.seed)
        images = ds.train_data[inds].transpose(0, 2, 3, 1)
        targets = labels[inds].tolist()
        transform = train_transform
    else:
        images = ds.test_data.transpose(0, 2, 3, 1)
        targets = ds.test_labels
        transform = convert

    ds = tnt.dataset.TensorDataset([images, targets])
    return ds.transform({0: transform})
main_small_sample_class_normalized.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录