Python 类 SVHN 的实例源码

data_loader.py 文件源码 项目:mnist-svhn-transfer 作者: yunjey 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def get_loader(config):
    """Build and return shuffled training DataLoaders for SVHN and MNIST.

    Args:
        config: object exposing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        tuple: ``(svhn_loader, mnist_loader)``.
    """
    # transforms.Resize replaces the deprecated (and since removed) Scale;
    # the resize semantics are the same.
    # SVHN images are RGB while MNIST images are single-channel, so each gets
    # Normalize statistics matching its channel count: a 3-element mean/std
    # applied to a 1-channel tensor raises a broadcast error in current
    # torchvision.
    svhn_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    mnist_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=svhn_transform)
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=mnist_transform)

    def make_loader(dataset):
        # Both training loaders share batch size, worker count and shuffling.
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=config.batch_size,
                                           shuffle=True,
                                           num_workers=config.num_workers)

    svhn_loader = make_loader(svhn)
    mnist_loader = make_loader(mnist)
    return svhn_loader, mnist_loader
data_loader.py 文件源码 项目:DistanceGAN 作者: sagiebenaim 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def get_loader(config):
    """Build and return train and test DataLoaders for SVHN and MNIST.

    Args:
        config: object exposing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        tuple: ``(svhn_loader, mnist_loader, svhn_test_loader,
        mnist_test_loader)``. The training loaders shuffle; the test loaders
        keep dataset order.
    """
    # transforms.Resize replaces the deprecated (and since removed) Scale.
    # SVHN is RGB and MNIST is single-channel, so each gets Normalize stats
    # matching its channel count (3-element stats on a 1-channel tensor
    # raise a broadcast error in current torchvision).
    svhn_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    mnist_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=svhn_transform, split='train')
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=mnist_transform, train=True)

    svhn_test = datasets.SVHN(root=config.svhn_path, download=True, transform=svhn_transform, split='test')
    mnist_test = datasets.MNIST(root=config.mnist_path, download=True, transform=mnist_transform, train=False)

    def make_loader(dataset, shuffle):
        # All four loaders share batch size and worker count.
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=config.batch_size,
                                           shuffle=shuffle,
                                           num_workers=config.num_workers)

    svhn_loader = make_loader(svhn, shuffle=True)
    mnist_loader = make_loader(mnist, shuffle=True)
    svhn_test_loader = make_loader(svhn_test, shuffle=False)
    mnist_test_loader = make_loader(mnist_test, shuffle=False)

    return svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader
dataset.py 文件源码 项目:pytorch-playground 作者: aaron-xichen 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def svhn_target_transform(target):
    """Map a raw SVHN label to the 0-9 class index used by ``get``.

    Digits 1-9 map to classes 0-8 and digit 0 maps to class 9, i.e. the
    class index is ``(label - 1) % 10``. Digit 0 may arrive either as the
    original SVHN label 10 or as 0 (newer torchvision remaps 10 to 0 itself);
    both give class 9, so the encoding is identical under either convention.

    Args:
        target: the label as an int / scalar, or a one-element indexable
            (presumably what older torchvision returned, since the original
            code indexed ``target[0]``).

    Returns:
        int: class index in ``range(10)``.
    """
    try:
        label = int(target[0])
    except (TypeError, IndexError):
        # Scalar target (int, NumPy scalar, 0-d tensor) is not indexable.
        label = int(target)
    return (label - 1) % 10


def get(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    """Build SVHN DataLoaders for the train and/or test splits.

    Args:
        batch_size: samples per batch for every returned loader.
        data_root: parent directory; data lives in ``<data_root>/svhn-data``
            (``~`` is expanded).
        train: if true, include a shuffled loader over the ``'train'`` split.
        val: if true, include an unshuffled loader over the ``'test'`` split.
        **kwargs: forwarded to ``torch.utils.data.DataLoader``.
            ``num_workers`` defaults to 1; ``input_size`` is accepted and
            discarded.

    Returns:
        A single DataLoader when exactly one split is requested, otherwise a
        list of loaders in ``[train, test]`` order (empty if neither).
    """
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    # The transforms are stateless, so one pipeline serves both splits.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def make_loader(split, shuffle):
        dataset = datasets.SVHN(
            root=data_root, split=split, download=True,
            transform=transform,
            target_transform=svhn_target_transform,
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    ds = []
    if train:
        ds.append(make_loader('train', shuffle=True))
    if val:
        ds.append(make_loader('test', shuffle=False))
    return ds[0] if len(ds) == 1 else ds
dataset_SVHN.py 文件源码 项目:DisentangleVAE 作者: Jueast 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def __init__(self, batchsize, train=True):
        """Load SVHN and set up a shuffled DataLoader plus its iterator.

        Args:
            batchsize: number of samples per batch.
            train: if true load the ``"train"`` split, otherwise ``"test"``.
        """
        Dataset.__init__(self)
        data_root = join(dirname(realpath(__file__)), 'SVHN_data')
        self.name = "svhn"
        # Pixel value range produced by ToTensor below.
        self.range = [0.0, 1.0]
        self.data_dims = [3, 32, 32]
        self.batchsize = batchsize
        # Honour `train`: the split used to be hard-coded to "train", so
        # train=False silently loaded the training set.
        split = "train" if train else "test"
        self.data = dsets.SVHN(root=data_root,
                               download=True,
                               split=split,
                               transform=transforms.Compose([
                                   transforms.ToTensor()]))
        # Attribute name (sic) kept unchanged; other code may reference it.
        self.dataloder = tdata.DataLoader(self.data, self.batchsize, shuffle=True)
        self.iter = iter(self.dataloder)
        self._index = 0


问题


面经


文章

微信
公众号

扫码关注公众号