def get_loaders(args):
    """Build the train/test ``DataLoader`` pair for the dataset named in ``args``.

    Args:
        args: Namespace-like object providing
            ``dataset`` -- either ``'mnist'`` or ``'cifar-10'``;
            ``batchSz`` -- batch size used for both loaders;
            ``cuda`` -- when truthy, loaders use one worker process and
            pinned memory.

    Returns:
        ``(trainLoader, testLoader)``. The training loader shuffles every
        epoch; the test loader iterates in a fixed order.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # Worker/pinning options only make sense when batches are moved to a GPU.
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    if args.dataset == 'mnist':
        # Commonly used MNIST mean/std normalization constants.
        normTransform = transforms.Normalize((0.1307,), (0.3081,))
        # MNIST gets no augmentation, so both splits share the same steps.
        trainTransform = transforms.Compose([transforms.ToTensor(), normTransform])
        testTransform = transforms.Compose([transforms.ToTensor(), normTransform])
        trainSet = dset.MNIST('data/mnist', train=True, download=True,
                              transform=trainTransform)
        # download=True mirrors the CIFAR-10 branch so the test split is
        # fetched even if it is requested on its own.
        testSet = dset.MNIST('data/mnist', train=False, download=True,
                             transform=testTransform)
    elif args.dataset == 'cifar-10':
        # Per-channel normalization statistics (presumably computed over the
        # CIFAR-10 training images).
        normMean = [0.49139968, 0.48215827, 0.44653124]
        normStd = [0.24703233, 0.24348505, 0.26158768]
        normTransform = transforms.Normalize(normMean, normStd)
        # Training images get random crop (with 4px padding) + horizontal
        # flip augmentation; test images are only tensorized and normalized.
        trainTransform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normTransform,
        ])
        testTransform = transforms.Compose([
            transforms.ToTensor(),
            normTransform,
        ])
        trainSet = dset.CIFAR10(root='data/cifar', train=True, download=True,
                                transform=trainTransform)
        testSet = dset.CIFAR10(root='data/cifar', train=False, download=True,
                               transform=testTransform)
    else:
        # An explicit exception (unlike `assert`) is not stripped under -O,
        # which would otherwise surface as an UnboundLocalError at `return`.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'mnist' or 'cifar-10'".format(
                args.dataset))

    trainLoader = torch.utils.data.DataLoader(
        trainSet, batch_size=args.batchSz, shuffle=True, **kwargs)
    testLoader = torch.utils.data.DataLoader(
        testSet, batch_size=args.batchSz, shuffle=False, **kwargs)
    return trainLoader, testLoader
评论列表
文章目录