def get_usps(train):
    """Build and return a shuffled data loader over the USPS dataset.

    Args:
        train: Forwarded unchanged as ``USPS(train=...)``. Presumably
            selects the training split when truthy and the test split
            otherwise -- confirm against the ``USPS`` class.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the USPS dataset, batched
        by ``params.batch_size`` with shuffling enabled.
    """
    # Per-image pipeline: convert to a tensor first, then normalize with the
    # statistics configured in ``params``.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=params.dataset_mean,
        std=params.dataset_std,
    )
    image_pipeline = transforms.Compose([to_tensor, normalize])

    # ``download=True`` is always passed; presumably ``USPS`` fetches the data
    # into ``params.data_root`` when it is missing -- verify in ``USPS``.
    dataset = USPS(
        root=params.data_root,
        train=train,
        transform=image_pipeline,
        download=True,
    )

    # Shuffling is enabled regardless of the value of ``train``.
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=params.batch_size,
        shuffle=True,
    )
# [web-page residue, not code] Comment list
# [web-page residue, not code] Table of contents