def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', transform_fn=None,
              shuffle=True):
    """Build a torchvision dataset and a ``DataLoader`` over it.

    Args:
        datasetName: ``'svhn'`` or ``'mnist'``.
        dataroot: Directory passed as ``root`` to the dataset. ``download=True``
            is always passed, so torchvision may fetch the data into it.
        originalSize: Resize size applied before random cropping (only used by
            the default ``'train'``/``'extra'`` transform).
        imageSize: Random-crop size for ``'train'``/``'extra'``; resize size for
            ``'test'``.
        batchSize: Mini-batch size for the ``DataLoader``.
        workers: Number of ``DataLoader`` worker processes (coerced with ``int``).
        mean, std: Per-channel statistics for ``transforms.Normalize`` in the
            default transforms. NOTE(review): the defaults have three entries,
            while MNIST images are presumably single-channel; callers using
            ``'mnist'`` likely need 1-element tuples — confirm.
        split: ``'train'``, ``'extra'`` or ``'test'``. For SVHN, ``'train'`` is
            remapped to the ``'extra'`` split. For MNIST, any value other than
            ``'train'`` loads the test set.
        transform_fn: Optional transform overriding the defaults. If ``None``
            and ``split`` is not one of the three values above, no transform
            is applied.
        shuffle: Whether the ``DataLoader`` shuffles samples (default ``True``,
            matching the previous hard-coded behavior).

    Returns:
        A ``(dataloader, dataset)`` tuple.

    Raises:
        ValueError: If ``datasetName`` is not ``'svhn'`` or ``'mnist'``.
        RuntimeError: If the constructed dataset is empty.
    """
    # Validate first so an unsupported name fails fast with a clear message
    # (previously it surfaced as UnboundLocalError on ``dataset``), before any
    # imports or downloads happen.
    if datasetName not in ('svhn', 'mnist'):
        raise ValueError(
            "Unsupported datasetName %r; expected 'svhn' or 'mnist'" % (datasetName,))

    import torch.utils.data
    import torchvision.transforms as transforms

    # ``transforms.Scale`` was renamed to ``Resize`` and later removed from
    # torchvision; use the new name when present, fall back for old versions.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale

    if transform_fn is None and split in ('train', 'extra'):
        # Training: resize, then random-crop to the target size.
        transform_fn = transforms.Compose([
            resize(originalSize),
            transforms.RandomCrop(imageSize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif transform_fn is None and split == 'test':
        # Evaluation: resize directly to the target size, no random crop.
        transform_fn = transforms.Compose([
            resize(imageSize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    if datasetName == 'svhn':
        from torchvision.datasets.svhn import SVHN as commonDataset
        # A 'train' request is served from SVHN's 'extra' split.
        if split == 'train':
            split = 'extra'
        dataset = commonDataset(root=dataroot,
                                download=True,
                                split=split,
                                transform=transform_fn)
    else:  # datasetName == 'mnist' (validated above)
        from torchvision.datasets.mnist import MNIST as commonDataset
        # Only 'train' selects the training set; every other split loads the
        # test set.
        dataset = commonDataset(root=dataroot,
                                download=True,
                                train=(split == 'train'),
                                transform=transform_fn)

    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if len(dataset) == 0:
        raise RuntimeError(
            "Dataset %r (split=%r) under %r is empty" % (datasetName, split, dataroot))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batchSize,
                                             shuffle=shuffle,
                                             num_workers=int(workers))
    return dataloader, dataset
评论列表
文章目录