def get_data_loader(dataset_name,
                    batch_size=1,
                    dataset_transforms=None,
                    is_training_set=True,
                    shuffle=True):
    """Build a ``DataLoader`` over a dataset class looked up by name.

    The dataset class is fetched from the ``datasets`` namespace via
    ``getattr`` and instantiated with ``root=DATA_DIR`` and
    ``download=True``.

    Args:
        dataset_name: Attribute name of the dataset class on ``datasets``.
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
        dataset_transforms: Optional iterable (list, tuple, generator, ...)
            of extra transforms appended after ``transforms.ToTensor()``.
            ``None`` or an empty value means only ``ToTensor`` is applied.
        is_training_set: Forwarded to the dataset constructor as ``train``.
        shuffle: Forwarded to ``DataLoader`` as ``shuffle``.

    Returns:
        A ``DataLoader`` wrapping the constructed dataset.

    Raises:
        AttributeError: If ``datasets`` has no attribute ``dataset_name``.
    """
    # Copy into a fresh list: this lets tuples and other iterables be
    # concatenated with the leading ToTensor step, and never aliases the
    # caller's own sequence.
    extra_transforms = list(dataset_transforms) if dataset_transforms else []
    pipeline = transforms.Compose([transforms.ToTensor()] + extra_transforms)

    dataset_cls = getattr(datasets, dataset_name)
    dataset = dataset_cls(
        root=DATA_DIR,
        train=is_training_set,
        transform=pipeline,
        download=True,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
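# Web-page residue from the source article (not part of the code):
# "评论列表" (comment list) / "文章目录" (table of contents).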