def get_test_loader(data_dir,
                    name,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 or CIFAR-100 test split.

    The dataset is requested with `download=True` and `train=False`, and
    every sample goes through `ToTensor` followed by channel normalization.

    NOTE(review): the original guidance for this helper recommends
    `num_workers=1` and `pin_memory=True` when using CUDA - confirm that
    this still matches your setup.

    Params
    ------
    - data_dir: path directory to the dataset.
    - name: string specifying which dataset to load. Must be `cifar10`
      or `cifar100`.
    - batch_size: how many samples per batch to load.
    - shuffle: forwarded to the DataLoader's `shuffle` argument. The default
      of True is kept for backward compatibility.
      NOTE(review): evaluation usually does not need shuffling - confirm.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.

    Raises
    ------
    - ValueError: if `name` is neither `cifar10` nor `cifar100`.
    """
    # Validate first: previously any unrecognised name (typo, wrong case,
    # unsupported dataset) silently fell through to CIFAR-100.
    if name not in ('cifar10', 'cifar100'):
        raise ValueError(
            "unsupported dataset name {!r}; expected 'cifar10' or "
            "'cifar100'".format(name))

    # NOTE(review): these values match the widely used ImageNet channel
    # statistics rather than CIFAR-specific ones - confirm this is intended.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # define transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Both dataset classes are built with the same keyword arguments, so
    # pick the class and construct it once.
    dataset_cls = datasets.CIFAR10 if name == 'cifar10' else datasets.CIFAR100
    dataset = dataset_cls(root=data_dir,
                          train=False,
                          download=True,
                          transform=transform)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              pin_memory=pin_memory)
    return data_loader
评论列表
文章目录