import torch
import torch.utils.data


def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None):
    # Select the dataset class and the matching transform module.
    if datasetName == 'trans':
        from datasets.trans import trans as commonDataset
        import transforms.pix2pix as transforms
    elif datasetName == 'folder':
        from torchvision.datasets.folder import ImageFolder as commonDataset
        import torchvision.transforms as transforms
    elif datasetName == 'pix2pix':
        from datasets.pix2pix import pix2pix as commonDataset
        import transforms.pix2pix as transforms
    else:
        raise ValueError('unknown datasetName: %s' % datasetName)

    if datasetName != 'folder':
        # The custom datasets accept a seed argument for reproducible sampling.
        if split == 'train':
            # Training: random crop + horizontal flip for data augmentation.
            dataset = commonDataset(root=dataroot,
                                    transform=transforms.Compose([
                                        transforms.Scale(originalSize),
                                        transforms.RandomCrop(imageSize),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std),
                                    ]),
                                    seed=seed)
        else:
            # Validation/test: deterministic center crop, no flipping.
            dataset = commonDataset(root=dataroot,
                                    transform=transforms.Compose([
                                        transforms.Scale(originalSize),
                                        transforms.CenterCrop(imageSize),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std),
                                    ]),
                                    seed=seed)
    else:
        # torchvision's ImageFolder does not take a seed argument.
        # Note: transforms.Scale is deprecated in newer torchvision in favor of Resize.
        if split == 'train':
            dataset = commonDataset(root=dataroot,
                                    transform=transforms.Compose([
                                        transforms.Scale(originalSize),
                                        transforms.RandomCrop(imageSize),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std),
                                    ]))
        else:
            dataset = commonDataset(root=dataroot,
                                    transform=transforms.Compose([
                                        transforms.Scale(originalSize),
                                        transforms.CenterCrop(imageSize),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std),
                                    ]))
    assert dataset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batchSize,
                                             shuffle=shuffle,
                                             num_workers=int(workers))
    return dataloader
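

# Usage sketch, not part of the original code: the paths './facades/train' and
# './facades/val' are hypothetical, and the pix2pix dataset is assumed to yield
# paired (input, target) image tensors; adjust sizes/paths to your own data.
if __name__ == '__main__':
    trainLoader = getLoader('pix2pix', './facades/train',
                            originalSize=286, imageSize=256,
                            batchSize=4, workers=2,
                            split='train', shuffle=True, seed=123)
    valLoader = getLoader('pix2pix', './facades/val',
                          originalSize=256, imageSize=256,
                          batchSize=1, workers=1,
                          split='val', shuffle=False, seed=123)
    # Sanity check: print the tensor shapes of the first training batch.
    for batch in trainLoader:
        imgA, imgB = batch[0], batch[1]
        print(imgA.size(), imgB.size())
        break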