util.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def get_data_loader(dataset_name,
                    batch_size=1,
                    dataset_transforms=None,
                    is_training_set=True,
                    shuffle=True):
    if not dataset_transforms:
        dataset_transforms = []
    trans = transforms.Compose([transforms.ToTensor()] + dataset_transforms)
    dataset = getattr(datasets, dataset_name)
    return DataLoader(
        dataset(root=DATA_DIR,
                train=is_training_set,
                transform=trans,
                download=True),
        batch_size=batch_size,
        shuffle=shuffle
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号