data_loader.py source code

Project: DenseNet  Author: kevinzakka

import torch
from torchvision import datasets, transforms


def get_test_loader(data_dir,
                    name,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 or CIFAR-100 dataset.

    If using CUDA, set pin_memory to True for faster host-to-GPU copies.

    Params
    ------
    - data_dir: path to the directory that holds the dataset (it is
      downloaded there if not already present).
    - name: string specifying which dataset to load: `cifar10` or
      `cifar100`.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.
    """
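    # Per-channel ImageNet mean/std (not statistics computed on CIFAR).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Convert images to float tensors in [0, 1] (C x H x W), then normalize.
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])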

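    if name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir,
                                   train=False,
                                   download=True,
                                   transform=transform)
    elif name == 'cifar100':
        dataset = datasets.CIFAR100(root=data_dir,
                                    train=False,
                                    download=True,
                                    transform=transform)
    else:
        # Reject unknown names instead of silently loading CIFAR-100.
        raise ValueError(f"Unknown dataset name: {name!r}")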

    data_loader = torch.utils.data.DataLoader(dataset, 
                                              batch_size=batch_size, 
                                              shuffle=shuffle, 
                                              num_workers=num_workers,
                                              pin_memory=pin_memory)

    return data_loader
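The transform built in `get_test_loader` first applies `transforms.ToTensor()`, which turns an H x W x C uint8 image with values in [0, 255] into a C x H x W float tensor in [0, 1], and then `transforms.Normalize`, which computes `(x - mean[c]) / std[c]` for each channel `c`. A minimal plain-Python sketch of that arithmetic, assuming the image is given as nested lists of RGB triples (the helpers `to_chw_unit_range` and `normalize_per_channel` are hypothetical illustrations, not torchvision API):

```python
from typing import List

# Same constants as in get_test_loader (ImageNet channel statistics).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

Pixel = List[int]  # one (R, G, B) triple of uint8 values


def to_chw_unit_range(image: List[List[Pixel]]) -> List[List[List[float]]]:
    """Hypothetical stand-in for ToTensor on uint8 input:
    H x W x C values in [0, 255] -> C x H x W floats in [0, 1]."""
    height, width = len(image), len(image[0])
    channels = len(image[0][0])
    return [
        [[image[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


def normalize_per_channel(tensor, mean, std):
    """Hypothetical stand-in for Normalize: (x - mean[c]) / std[c] per channel."""
    return [
        [[(value - mean[c]) / std[c] for value in row] for row in plane]
        for c, plane in enumerate(tensor)
    ]


# A 1 x 2 RGB "image": one black pixel followed by one white pixel.
image = [[[0, 0, 0], [255, 255, 255]]]
out = normalize_per_channel(to_chw_unit_range(image), IMAGENET_MEAN, IMAGENET_STD)
for c, plane in enumerate(out):
    print(c, [round(v, 4) for v in plane[0]])
```

With these constants a black pixel maps to about -2.12 in the red channel and a white pixel to about 2.25, so the network sees inputs centered near zero rather than raw byte values.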
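The `batch_size` and `shuffle` arguments that `get_test_loader` forwards to `torch.utils.data.DataLoader` decide how the 10,000 images of the CIFAR-10 or CIFAR-100 test split are grouped. The plain-Python sketch below models only that grouping: with `shuffle=True`, DataLoader draws a fresh random permutation of the indices each time it is iterated, and with its default `drop_last=False` the final batch holds whatever is left over. The helper `batch_indices` and its `seed` argument are hypothetical (the real sampler uses PyTorch's RNG), and the sketch ignores `num_workers` and `pin_memory`, which affect how batches are loaded rather than which indices each batch contains.

```python
import math
import random
from typing import Iterator, List


def batch_indices(num_samples: int, batch_size: int, shuffle: bool,
                  seed: int = 0) -> Iterator[List[int]]:
    """Simplified model of DataLoader sampling: a (possibly shuffled)
    permutation of sample indices, cut into consecutive batches.
    The last batch may be smaller, as with drop_last=False."""
    indices = list(range(num_samples))
    if shuffle:
        # Hypothetical seeded RNG so the sketch is reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


# 10,000 test images split into batches of 64.
batches = list(batch_indices(10_000, batch_size=64, shuffle=True))
print(len(batches), math.ceil(10_000 / 64), len(batches[-1]))  # prints: 157 157 16
```

Shuffling changes only the order in which test samples are visited; every index still appears exactly once per pass, so accuracy computed over a full pass does not depend on `shuffle`.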