def get_loader(config):
    """Build shuffled training DataLoaders for the SVHN and MNIST datasets.

    Args:
        config: Object exposing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        Tuple ``(svhn_loader, mnist_loader)``.
    """
    # ``Scale`` was renamed to ``Resize`` in torchvision and later removed;
    # prefer the new name but keep working on old installations.
    resize = getattr(transforms, "Resize", None) or transforms.Scale

    def make_transform(num_channels):
        # Normalize needs one mean/std entry per image channel. Reusing the
        # 3-channel stats for single-channel MNIST fails on newer torchvision.
        stats = (0.5,) * num_channels
        return transforms.Compose([
            resize(config.image_size),
            transforms.ToTensor(),
            transforms.Normalize(stats, stats),
        ])

    def make_loader(dataset):
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=config.batch_size,
                                           shuffle=True,
                                           num_workers=config.num_workers)

    # SVHN images are RGB (3 channels); MNIST images are grayscale (1 channel).
    svhn = datasets.SVHN(root=config.svhn_path, download=True,
                         transform=make_transform(3))
    mnist = datasets.MNIST(root=config.mnist_path, download=True,
                           transform=make_transform(1))

    return make_loader(svhn), make_loader(mnist)
# Example source code using the Python class SVHN (original heading: "python类SVHN的实例源码")
def get_loader(config):
    """Build train and test DataLoaders for the SVHN and MNIST datasets.

    Training loaders are shuffled; test loaders keep dataset order.

    Args:
        config: Object exposing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        Tuple ``(svhn_loader, mnist_loader, svhn_test_loader,
        mnist_test_loader)``.
    """
    # ``Scale`` was renamed to ``Resize`` in torchvision and later removed;
    # prefer the new name but keep working on old installations.
    resize = getattr(transforms, "Resize", None) or transforms.Scale

    def make_transform(num_channels):
        # Normalize needs one mean/std entry per image channel. Reusing the
        # 3-channel stats for single-channel MNIST fails on newer torchvision.
        stats = (0.5,) * num_channels
        return transforms.Compose([
            resize(config.image_size),
            transforms.ToTensor(),
            transforms.Normalize(stats, stats),
        ])

    def make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=config.batch_size,
                                           shuffle=shuffle,
                                           num_workers=config.num_workers)

    # SVHN images are RGB (3 channels); MNIST images are grayscale (1 channel).
    svhn_transform = make_transform(3)
    mnist_transform = make_transform(1)

    # Datasets are created (and downloaded if missing) in the original order.
    svhn = datasets.SVHN(root=config.svhn_path, download=True,
                         transform=svhn_transform, split='train')
    mnist = datasets.MNIST(root=config.mnist_path, download=True,
                           transform=mnist_transform, train=True)
    svhn_test = datasets.SVHN(root=config.svhn_path, download=True,
                              transform=svhn_transform, split='test')
    mnist_test = datasets.MNIST(root=config.mnist_path, download=True,
                                transform=mnist_transform, train=False)

    return (make_loader(svhn, shuffle=True),
            make_loader(mnist, shuffle=True),
            make_loader(svhn_test, shuffle=False),
            make_loader(mnist_test, shuffle=False))
def _svhn_target_transform(target):
    """Map a raw SVHN label to the zero-based class index used by ``get``.

    Accepts both an indexable one-element label (``target[0]``) and a plain
    scalar label. Defined at module level rather than inside ``get`` so that
    it can be pickled when DataLoader workers are started with ``spawn``.

    The modulo keeps the mapping ``digit d -> d - 1`` (digit 0 -> class 9)
    whether digit 0 arrives as label 10 or as label 0.
    """
    try:
        label = target[0]
    except (TypeError, IndexError):
        # Scalar label (Python int / 0-d value): nothing to unwrap.
        label = target
    return (int(label) - 1) % 10


def get(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    """Build SVHN DataLoader(s) for the requested splits.

    Args:
        batch_size: Batch size passed to every DataLoader.
        data_root: Base directory; data is stored in its ``svhn-data``
            subdirectory (``~`` is expanded).
        train: If true, include a shuffled loader over the ``train`` split.
        val: If true, include an unshuffled loader over the ``test`` split.
        **kwargs: Extra DataLoader keyword arguments. ``num_workers``
            defaults to 1 and ``input_size`` is discarded.

    Returns:
        A single DataLoader when exactly one split is requested; otherwise a
        list of the requested loaders (train first), which is empty when both
        flags are false.
    """
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    # ``input_size`` is not a DataLoader argument, so drop it if present.
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    def build_loader(split, shuffle):
        dataset = datasets.SVHN(
            root=data_root, split=split, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
            target_transform=_svhn_target_transform,
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    loaders = []
    if train:
        loaders.append(build_loader('train', shuffle=True))
    if val:
        loaders.append(build_loader('test', shuffle=False))
    return loaders[0] if len(loaders) == 1 else loaders
def __init__(self, batchsize, train=True):
    """Set up the SVHN dataset and a shuffled batch iterator over it.

    Args:
        batchsize: Number of samples per batch drawn from the DataLoader.
        train: If true, load the "train" split; otherwise load the "test"
            split. (Previously the flag was ignored and "train" was always
            loaded.)
    """
    Dataset.__init__(self)
    # Data is stored next to this source file, in an ``SVHN_data`` folder.
    data_root = join(dirname(realpath(__file__)), 'SVHN_data')
    self.name = "svhn"
    self.range = [0.0, 1.0]  # presumably the pixel value range; confirm
    self.data_dims = [3, 32, 32]
    self.batchsize = batchsize

    split = "train" if train else "test"
    self.data = dsets.SVHN(root=data_root,
                           download=True,
                           split=split,
                           transform=transforms.Compose([
                               transforms.ToTensor()]))
    # NOTE: the attribute name ``dataloder`` (sic) is kept unchanged because
    # code outside this method may read it.
    self.dataloder = tdata.DataLoader(self.data, self.batchsize, shuffle=True)
    self.iter = iter(self.dataloder)
    self._index = 0