datasets.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:generative_models 作者: j-min 项目源码 文件源码
def get_dataset(config):
    """Return dataset class"""

    torchvision_datasets = [
        'LSUN',
        'CocoCaptions',
        'CocoDetection',
        'CIFAR10',
        'CIFAR100',
        'FashionMNIST',
        'MNIST',
        'STL10',
        'SVHN',
        'PhotoTour',
        'SEMEION']

    # unaligned_datasets = [
    #     'horse2zebra'
    # ]

    if config.dataset in torchvision_datasets:
        dataset = getattr(datasets, config.dataset)(
            root=config.dataset_dir,
            train=config.isTrain,
            download=True,
            transform=base_transform(config))
    else:
        dataset = get_custom_dataset(config)
    return dataset
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号