def get_dataset(config):
    """Return the dataset instance selected by ``config.dataset``.

    Supported torchvision dataset names are looked up on ``datasets``. They
    are constructed rooted at ``config.dataset_dir`` with
    ``base_transform(config)`` as the transform. Any other name is delegated
    to ``get_custom_dataset(config)``.

    torchvision's constructors do not share one signature, so the extra
    keyword arguments are chosen per dataset family:

    * CIFAR10, CIFAR100, FashionMNIST, MNIST: ``train=config.isTrain`` and
      ``download=True``.
    * STL10, SVHN: ``split='train'`` or ``'test'`` (from ``config.isTrain``)
      and ``download=True``.
    * SEMEION: ``download=True`` only (its constructor has no split argument).
    * LSUN: ``classes='train'`` or ``'test'``; its constructor has no
      ``download`` option.

    Args:
        config: options object. This function reads ``dataset``, and as
            needed ``dataset_dir`` and ``isTrain``. ``base_transform`` and
            ``get_custom_dataset`` may read further fields.

    Returns:
        The constructed dataset.

    Raises:
        TypeError: for CocoCaptions, CocoDetection and PhotoTour. Their
            constructors require an argument (``annFile`` / ``name``) that
            this function has no config field for.
    """
    name = config.dataset

    # These constructors have a required argument with no source in config.
    # Calling them with the generic arguments fails with an unhelpful
    # TypeError, so fail early with an explicit message instead.
    unsupported = {
        'CocoCaptions': 'annFile',
        'CocoDetection': 'annFile',
        'PhotoTour': 'name',
    }
    if name in unsupported:
        raise TypeError(
            f"torchvision dataset {name!r} requires the constructor argument "
            f"{unsupported[name]!r}, which get_dataset cannot take from config"
        )

    if name in ('CIFAR10', 'CIFAR100', 'FashionMNIST', 'MNIST'):
        extra = {'train': config.isTrain, 'download': True}
    elif name in ('STL10', 'SVHN'):
        # These select the subset via ``split`` rather than ``train``.
        extra = {
            'split': 'train' if config.isTrain else 'test',
            'download': True,
        }
    elif name == 'SEMEION':
        extra = {'download': True}
    elif name == 'LSUN':
        # LSUN selects the subset via ``classes`` and cannot download.
        extra = {'classes': 'train' if config.isTrain else 'test'}
    else:
        return get_custom_dataset(config)

    dataset_cls = getattr(datasets, name)
    return dataset_cls(
        root=config.dataset_dir,
        transform=base_transform(config),
        **extra,
    )
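# Page-navigation residue from the web page this snippet was copied from
# ("comment list", "table of contents"); not part of the code.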