def dataLoader(is_train=True, cuda=True, batch_size=64, shuffle=True):
    """Build a CIFAR-100 ``DataLoader`` and report the size of its dataset.

    Args:
        is_train: When true, load the training split and apply a random
            horizontal flip followed by a padded random 32x32 crop;
            otherwise load the test split without augmentation.
        cuda: Unused; kept so existing callers keep working.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.

    Returns:
        A ``(loader, size)`` tuple, where ``size`` is the number of samples
        in the selected split.
    """
    # The channel statistics are written on a 0-255 scale and divided by
    # 255 here; both splits share the same normalization.
    normalize = transforms.Normalize(
        mean=[n / 255. for n in [129.3, 124.1, 112.4]],
        std=[n / 255. for n in [68.2, 65.4, 70.4]])
    steps = [transforms.ToTensor(), normalize]
    if is_train:
        # Augmentation is applied to the training split only.
        steps = [transforms.RandomHorizontalFlip(),
                 transforms.RandomCrop(32, padding=4)] + steps

    # The dataset lives under the relative directory 'data'; no download is
    # requested, so the files must already be present there.
    dataset = td.CIFAR100('data', train=bool(is_train),
                          transform=transforms.Compose(steps))

    # len(dataset) is used instead of the `train_labels` / `test_labels`
    # attributes: newer torchvision releases expose `targets` instead, so
    # reading the old attributes raises AttributeError there.
    size = len(dataset)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle)
    return loader, size
# python类CIFAR100的实例源码 (example source code for the Python class CIFAR100)
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    """Return the torchvision dataset identified by ``name``.

    Args:
        name: One of 'cifar10', 'cifar100', 'mnist', 'stl10' or 'imagenet'.
        split: Split to load. For every dataset except STL10 only the
            distinction "equal to 'train'" versus anything else matters;
            STL10 receives the value unchanged.
        transform: Transform forwarded to the dataset.
        target_transform: Target transform forwarded to the dataset.
        download: Forwarded to the datasets that accept it (all but
            'imagenet', which is read from an existing image folder).
        datasets_path: Base directory; the dataset is rooted at
            ``<datasets_path>/<name>``.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    train = (split == 'train')
    root = os.path.join(datasets_path, name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    if name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    if name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    if name == 'stl10':
        # STL10 takes the split name itself rather than a boolean flag.
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    if name == 'imagenet':
        # ImageNet is read from '<root>/train' or '<root>/val'.
        root = os.path.join(root, 'train' if train else 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)
    # Previously an unknown name fell through and returned None, which only
    # surfaced later as an unrelated error at the call site.
    raise ValueError("Unknown dataset %s" % (name,))
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    """Construct one of the supported torchvision datasets by name.

    Args:
        name: Dataset key: 'cifar10', 'cifar100', 'mnist', 'stl10' or
            'imagenet'.
        split: Requested split. 'train' selects the training data; any
            other value selects the non-training data, except for STL10,
            which is handed the string as-is.
        transform: Transform passed through to the dataset.
        target_transform: Target transform passed through to the dataset.
        download: Passed through to every dataset except 'imagenet'.
        datasets_path: Directory under which ``name`` is used as the
            dataset root.

    Returns:
        The dataset instance.

    Raises:
        ValueError: If ``name`` does not match a supported dataset.
    """
    is_train = (split == 'train')
    root = os.path.join(datasets_path, name)
    common = dict(transform=transform, target_transform=target_transform)

    # Datasets whose constructors share the (root, train, ..., download)
    # signature. Names are looked up lazily so only the requested class is
    # touched.
    train_flag_datasets = {
        'cifar10': 'CIFAR10',
        'cifar100': 'CIFAR100',
        'mnist': 'MNIST',
    }
    if name in train_flag_datasets:
        dataset_cls = getattr(datasets, train_flag_datasets[name])
        return dataset_cls(root=root, train=is_train, download=download,
                           **common)

    if name == 'stl10':
        # STL10 selects its split by name, not by a boolean.
        return datasets.STL10(root=root, split=split, download=download,
                              **common)

    if name == 'imagenet':
        # Images are expected in a 'train' or 'val' sub-folder of the root.
        subdir = 'train' if is_train else 'val'
        return datasets.ImageFolder(root=os.path.join(root, subdir),
                                    **common)

    # Fail loudly instead of implicitly returning None for a bad name.
    raise ValueError("Unknown dataset %s" % (name,))
def get100(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    """Create CIFAR-100 data loaders for the requested splits.

    Args:
        batch_size: Batch size used by every loader.
        data_root: Base directory; the data is stored in its
            'cifar100-data' sub-directory (``~`` is expanded).
        train: Whether to build the (shuffled, augmented) training loader.
        val: Whether to build the (unshuffled) test loader.
        **kwargs: Extra ``DataLoader`` keyword arguments. 'num_workers'
            defaults to 1 and 'input_size' is discarded if present.

    Returns:
        The single loader when exactly one split is requested; otherwise a
        list of the built loaders (training first), which is empty when
        neither split is requested.
    """
    cifar_dir = os.path.expanduser(os.path.join(data_root, 'cifar100-data'))
    workers = kwargs.setdefault('num_workers', 1)
    # 'input_size' is not a DataLoader argument, so drop it before forwarding.
    kwargs.pop('input_size', None)
    print("Building CIFAR-100 data loader with {} workers".format(workers))

    def build_loader(is_train, steps):
        # Only the training loader is shuffled.
        dataset = datasets.CIFAR100(
            root=cifar_dir, train=is_train, download=True,
            transform=transforms.Compose(steps))
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train, **kwargs)

    loaders = []
    if train:
        loaders.append(build_loader(True, [
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    if val:
        loaders.append(build_loader(False, [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))

    if len(loaders) == 1:
        return loaders[0]
    return loaders
def get_test_loader(data_dir,
                    name,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Build a data loader over the test split of CIFAR-10 or CIFAR-100.

    The original author recommends num_workers=1 and pin_memory=True
    when running on CUDA.

    Params
    ------
    - data_dir: directory the dataset is stored in (downloaded if needed).
    - name: `cifar10` selects CIFAR-10; any other value selects CIFAR-100.
    - batch_size: number of samples per batch.
    - shuffle: passed to the loader to control sample shuffling.
    - num_workers: number of loader worker subprocesses.
    - pin_memory: passed to the loader's `pin_memory` option.

    Returns
    -------
    - data_loader: loader over the selected test set.
    """
    # Images are converted to tensors and then normalized per channel.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Everything that is not exactly 'cifar10' is treated as CIFAR-100.
    dataset_cls = datasets.CIFAR10 if name == 'cifar10' else datasets.CIFAR100
    test_set = dataset_cls(root=data_dir,
                           train=False,
                           download=True,
                           transform=preprocessing)

    return torch.utils.data.DataLoader(test_set,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=num_workers,
                                       pin_memory=pin_memory)
def get_data(args, train_flag=True):
    """Return the dataset selected by ``args.dataset``.

    Args:
        args: Object providing ``dataset`` (dataset name), ``dataroot``
            (location of the data) and ``image_size`` (target size for the
            resize and centre crop).
        train_flag: Selects the training split for CIFAR-10, CIFAR-100 and
            MNIST, and the 'train' versus 'val' sub-directory for CelebA.
            It is not used for the folder-based datasets or LSUN.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: If the dataset name is unknown, or if CelebA is
            requested with an image size other than 64.
    """
    # `transforms.Scale` was renamed to `Resize` and later removed from
    # torchvision; prefer `Resize` and fall back to `Scale` only on
    # releases that predate it.
    resize_cls = getattr(transforms, 'Resize', None)
    if resize_cls is None:
        resize_cls = transforms.Scale

    transform = transforms.Compose([
        resize_cls(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset in ['imagenet', 'folder', 'lfw']:
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transform)
    elif args.dataset == 'lsun':
        # The path is passed positionally because the keyword was renamed
        # from `db_path` to `root` across torchvision releases.
        dataset = dset.LSUN(args.dataroot,
                            classes=['bedroom_train'],
                            transform=transform)
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot,
                               download=True,
                               train=train_flag,
                               transform=transform)
    elif args.dataset == 'cifar100':
        dataset = dset.CIFAR100(root=args.dataroot,
                                download=True,
                                train=train_flag,
                                transform=transform)
    elif args.dataset == 'mnist':
        dataset = dset.MNIST(root=args.dataroot,
                             download=True,
                             train=train_flag,
                             transform=transform)
    elif args.dataset == 'celeba':
        imdir = 'train' if train_flag else 'val'
        dataroot = os.path.join(args.dataroot, imdir)
        if args.image_size != 64:
            raise ValueError('the image size for CelebA dataset need to be 64!')
        # CelebA uses its own crop-and-scale step instead of the generic
        # resize / centre-crop pipeline built above.
        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       ALICropAndScale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )
    else:
        raise ValueError("Unknown dataset %s" % (args.dataset))
    return dataset