def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root='./data', download=True, **kwargs):
    """
    Helper function for setting up pytorch data loaders for a semi-supervised dataset.

    :param dataset: dataset class to instantiate; called as
        dataset(root=..., mode=..., download=..., sup_num=..., use_cuda=...)
    :param use_cuda: use GPU(s) for training (forwarded to the dataset constructor)
    :param batch_size: size of a batch of data to output when iterating over the data loaders
    :param sup_num: number of supervised data examples; when None, only the
        "unsup" and "test" loaders are built
    :param root: where on the filesystem the dataset should be stored
    :param download: download the dataset (if it doesn't exist already)
    :param kwargs: other params for the pytorch DataLoader
    :return: dict of data loaders keyed by mode ("unsup", "test", "sup", "valid"),
        or the 2-tuple (unsup loader, test loader) when sup_num is None
    """
    # Fill in DataLoader defaults WITHOUT clobbering caller-supplied kwargs.
    # (The previous code replaced the whole dict whenever 'num_workers' was
    # absent, silently dropping options such as pin_memory or drop_last.)
    kwargs.setdefault('num_workers', 0)
    kwargs.setdefault('pin_memory', False)

    cached_data = {}
    loaders = {}
    # "unsup" and "test" are built first so the early return below can rely
    # on both being present in `loaders`.
    for mode in ["unsup", "test", "sup", "valid"]:
        if sup_num is None and mode == "sup":
            # in this special case, we do not want "sup" and "valid" data loaders
            return loaders["unsup"], loaders["test"]
        cached_data[mode] = dataset(root=root, mode=mode, download=download,
                                    sup_num=sup_num, use_cuda=use_cuda)
        loaders[mode] = DataLoader(cached_data[mode], batch_size=batch_size, shuffle=True, **kwargs)
    return loaders
# NOTE(review): the two lines below are stray scraped-webpage text
# ("comment list" / "article table of contents"), not Python code;
# commented out so the module imports cleanly.
# 评论列表
# 文章目录