python类ToTensor()的实例源码

neural_style.py 文件源码 项目:examples 作者: pytorch 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def stylize(args):
    """Run a saved TransformerNet over one content image and save the result.

    Reads ``args.content_image``, ``args.content_scale``, ``args.cuda``,
    ``args.model`` and ``args.output_image``.
    """
    # Convert to a tensor, then multiply every value by 255.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    raw_image = utils.load_image(args.content_image, scale=args.content_scale)
    # unsqueeze(0) adds a leading dimension of size 1 in front of the image.
    batch = preprocess(raw_image).unsqueeze(0)
    if args.cuda:
        batch = batch.cuda()
    batch = Variable(batch, volatile=True)

    net = TransformerNet()
    net.load_state_dict(torch.load(args.model))
    if args.cuda:
        net.cuda()

    stylized = net(batch)
    if args.cuda:
        stylized = stylized.cpu()
    # Only the first element of the output is written out.
    utils.save_image(args.output_image, stylized.data[0])
saliency.py 文件源码 项目:DeepLearning_PlantDiseases 作者: MarkoArsenovic 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def load_labels(data_dir, resize=(224, 224)):
    """Return the ``classes`` attribute of an ImageFolder over ``<data_dir>/train``.

    Args:
        data_dir: Directory that contains a ``train`` sub-directory.
        resize: Pair of sizes; only its maximum is used, as the crop size.
    """
    train_transform = transforms.Compose([
        transforms.RandomSizedCrop(max(resize)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_folder = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                        train_transform)
    return train_folder.classes
base_dataset.py 文件源码 项目:DeblurGAN 作者: KupynOrest 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def get_transform(opt):
    """Assemble the image transform pipeline selected by ``opt.resize_or_crop``.

    Recognised modes are 'resize_and_crop', 'crop', 'scale_width' and
    'scale_width_and_crop'; any other value adds no resize/crop step.
    A random horizontal flip is added when ``opt.isTrain`` is set and
    ``opt.no_flip`` is not. ToTensor and a 0.5/0.5 Normalize always come last.
    """
    mode = opt.resize_or_crop
    steps = []

    if mode == 'resize_and_crop':
        steps.extend([
            transforms.Scale([opt.loadSizeX, opt.loadSizeY], Image.BICUBIC),
            transforms.RandomCrop(opt.fineSize),
        ])
    elif mode == 'crop':
        steps.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'scale_width':
        # opt.fineSize is read when the transform runs, not when it is built.
        steps.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif mode == 'scale_width_and_crop':
        steps.extend([
            transforms.Lambda(lambda img: __scale_width(img, opt.loadSizeX)),
            transforms.RandomCrop(opt.fineSize),
        ])

    if opt.isTrain and not opt.no_flip:
        steps.append(transforms.RandomHorizontalFlip())

    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return transforms.Compose(steps)
data_loader.py 文件源码 项目:mnist-svhn-transfer 作者: yunjey 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def get_loader(config):
    """Builds and returns Dataloader for MNIST and SVHN dataset.

    Each dataset gets a Normalize whose mean/std length matches its channel
    count (3 for SVHN, 1 for MNIST). A single 3-value Normalize for both
    relies on old torchvision silently ignoring the surplus values on
    single-channel images; versions that broadcast mean/std over the
    channel axis reject it.

    Args:
        config: Object providing ``image_size``, ``svhn_path``,
            ``mnist_path``, ``batch_size`` and ``num_workers``.

    Returns:
        Tuple ``(svhn_loader, mnist_loader)``.
    """

    def build_transform(num_channels):
        # One 0.5 mean/std entry per image channel.
        stats = (0.5,) * num_channels
        return transforms.Compose([
            transforms.Scale(config.image_size),
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)])

    svhn = datasets.SVHN(root=config.svhn_path, download=True,
                         transform=build_transform(3))
    mnist = datasets.MNIST(root=config.mnist_path, download=True,
                           transform=build_transform(1))

    svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=config.num_workers)

    mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.num_workers)
    return svhn_loader, mnist_loader
mnist.py 文件源码 项目:pytorch-adda 作者: corenel 项目源码 文件源码 阅读 35 收藏 0 点赞 0 评论 0
def get_mnist(train):
    """Get MNIST dataset loader.

    Args:
        train: Forwarded to ``datasets.MNIST`` as its ``train`` argument.

    Returns:
        A shuffling DataLoader with batch size ``params.batch_size``.
    """
    # Image pre-processing: tensor conversion, then normalization with the
    # statistics configured in ``params``.
    normalize = transforms.Normalize(mean=params.dataset_mean,
                                     std=params.dataset_std)
    pre_process = transforms.Compose([transforms.ToTensor(), normalize])

    dataset = datasets.MNIST(root=params.data_root,
                             train=train,
                             transform=pre_process,
                             download=True)

    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=params.batch_size,
                                       shuffle=True)
train.py 文件源码 项目:torch_light 作者: ne7ermore 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def dataLoader(is_train=True, cuda=True, batch_size=64, shuffle=True):
    """Build a CIFAR100 loader for the train or the test split.

    Args:
        is_train: Truthy selects the training split (with random flip and
            padded random crop); falsy selects the test split.
        cuda: Not used by this function.
        batch_size: Batch size given to the DataLoader.
        shuffle: Shuffle flag given to the DataLoader.

    Returns:
        Tuple ``(loader, size)`` where ``size`` is the number of labels in
        the selected split.
    """
    # Per-channel statistics are written on a 0-255 scale and divided by 255.
    channel_mean = [n/255. for n in [129.3, 124.1, 112.4]]
    channel_std = [n/255. for n in [68.2,  65.4,  70.4]]

    steps = []
    if is_train:
        steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.RandomCrop(32, padding=4))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=channel_mean, std=channel_std))
    pipeline = transforms.Compose(steps)

    if is_train:
        dataset = td.CIFAR100('data', train=True, transform=pipeline)
        size = len(dataset.train_labels)
    else:
        dataset = td.CIFAR100('data', train=False, transform=pipeline)
        size = len(dataset.test_labels)

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle)
    return loader, size
util.py 文件源码 项目:pyro 作者: uber 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def get_data_loader(dataset_name,
                    batch_size=1,
                    dataset_transforms=None,
                    is_training_set=True,
                    shuffle=True):
    """Return a DataLoader over the ``datasets`` class named ``dataset_name``.

    Args:
        dataset_name: Attribute name looked up on ``datasets``.
        batch_size: Batch size given to the DataLoader.
        dataset_transforms: Optional list of transforms applied after
            ``ToTensor``; a falsy value means no extra transforms.
        is_training_set: Forwarded to the dataset as ``train``.
        shuffle: Shuffle flag given to the DataLoader.
    """
    extra_transforms = dataset_transforms if dataset_transforms else []
    pipeline = transforms.Compose([transforms.ToTensor()] + extra_transforms)

    dataset_cls = getattr(datasets, dataset_name)
    dataset = dataset_cls(root=DATA_DIR,
                          train=is_training_set,
                          transform=pipeline,
                          download=True)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
dataloader.py 文件源码 项目:ExperimentPackage_PyTorch 作者: ICEORY 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def mnist(self):
        """Create shuffling DataLoaders for the MNIST train and test splits.

        Reads ``self.train_batch_size``, ``self.test_batch_size`` and
        ``self.n_threads``. Only the training split is created with
        ``download=True``.

        Returns:
            Tuple ``(train_loader, test_loader)``.
        """
        data_root = "/home/dataset/mnist"
        norm_mean = [0.1307]
        norm_std = [0.3081]

        def make_transform():
            # Each split gets its own Compose instance.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])

        train_set = dsets.MNIST(data_root, train=True, download=True,
                                transform=make_transform())
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=self.train_batch_size, shuffle=True,
            num_workers=self.n_threads,
            pin_memory=False
        )

        test_set = dsets.MNIST(data_root, train=False,
                               transform=make_transform())
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=self.test_batch_size, shuffle=True,
            num_workers=self.n_threads,
            pin_memory=False
        )
        return train_loader, test_loader
test_preprocessor.py 文件源码 项目:open-reid 作者: Cysu 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_getitem(self):
        """Check the length of a Preprocessor over VIPeR and its first item.

        The dataset is created with ``download=True`` under
        ``/tmp/open-reid/viper``. The first item must unpack into
        ``(img, pid, camid)`` with ``img.size()`` equal to ``(3, 224, 224)``.
        """
        import torchvision.transforms as t
        from reid.datasets.viper import VIPeR
        from reid.utils.data.preprocessor import Preprocessor

        root, split_id, num_val = '/tmp/open-reid/viper', 0, 100
        dataset = VIPeR(root, split_id=split_id, num_val=num_val, download=True)

        preproc = Preprocessor(dataset.train, root=dataset.images_dir,
                               transform=t.Compose([
                                   t.Scale(256),
                                   t.CenterCrop(224),
                                   t.ToTensor(),
                                   t.Normalize(mean=[0.485, 0.456, 0.406],
                                               std=[0.229, 0.224, 0.225])
                               ]))
        # assertEqual replaces the deprecated alias assertEquals, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(len(preproc), len(dataset.train))
        img, pid, camid = preproc[0]
        self.assertEqual(img.size(), (3, 224, 224))
usps.py 文件源码 项目:pytorch-arda 作者: corenel 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def get_usps(train):
    """Get USPS dataset loader.

    Args:
        train: Forwarded to ``USPS`` as its ``train`` argument.

    Returns:
        A shuffling DataLoader with batch size ``params.batch_size``.
    """
    # Pre-processing steps: tensor conversion, then normalization with the
    # statistics configured in ``params``.
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=params.dataset_mean,
                             std=params.dataset_std),
    ]

    dataset = USPS(root=params.data_root,
                   train=train,
                   transform=transforms.Compose(steps),
                   download=True)

    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=params.batch_size,
                                         shuffle=True)
    return loader
main.py 文件源码 项目:SimGAN_pytorch 作者: AlexHex7 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def load_data(self):
        """Create ``self.syn_train_loader`` and ``self.real_loader``.

        Both loaders wrap an ImageFolder (rooted at ``cfg.syn_path`` and
        ``cfg.real_path`` respectively) with the same transform and print
        their batch counts.
        """
        print('=' * 50)
        print('Loading data...')

        shared_transform = transforms.Compose([
            transforms.ImageOps.grayscale,
            transforms.Scale((cfg.img_width, cfg.img_height)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        def folder_loader(root):
            # ImageFolder over `root`, wrapped in a shuffling, pinned loader.
            folder = torchvision.datasets.ImageFolder(root=root,
                                                      transform=shared_transform)
            return Data.DataLoader(folder, batch_size=cfg.batch_size,
                                   shuffle=True, pin_memory=True)

        self.syn_train_loader = folder_loader(cfg.syn_path)
        print('syn_train_batch %d' % len(self.syn_train_loader))

        self.real_loader = folder_loader(cfg.real_path)
        print('real_batch %d' % len(self.real_loader))
spatial_dataloader.py 文件源码 项目:two-stream-action-recognition 作者: jeffreyhuang1 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def train(self):
        """Build the shuffling DataLoader over the spatial training set.

        Reads ``self.dic_training``, ``self.data_path``, ``self.BATCH_SIZE``
        and ``self.num_workers``. Prints the dataset length and the size of
        ``training_set[1][0]['img1']``.

        Returns:
            The training DataLoader.
        """
        training_set = spatial_dataset(dic=self.dic_training, root_dir=self.data_path, mode='train', transform = transforms.Compose([
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
                ]))
        # Single-argument print(...) produces the same text on Python 2 and
        # Python 3; the former print statements were Python 2 only.
        print('==> Training data : {} frames'.format(len(training_set)))
        print(training_set[1][0]['img1'].size())

        train_loader = DataLoader(
            dataset=training_set,
            batch_size=self.BATCH_SIZE,
            shuffle=True,
            num_workers=self.num_workers)
        return train_loader
spatial_dataloader.py 文件源码 项目:two-stream-action-recognition 作者: jeffreyhuang1 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def validate(self):
        """Build the non-shuffling DataLoader over the spatial validation set.

        Reads ``self.dic_testing``, ``self.data_path``, ``self.BATCH_SIZE``
        and ``self.num_workers``. Prints the dataset length and the size of
        ``validation_set[1][1]``.

        Returns:
            The validation DataLoader.
        """
        validation_set = spatial_dataset(dic=self.dic_testing, root_dir=self.data_path, mode='val', transform = transforms.Compose([
                transforms.Scale([224,224]),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
                ]))

        # Single-argument print(...) produces the same text on Python 2 and
        # Python 3; the former print statements were Python 2 only.
        print('==> Validation data : {} frames'.format(len(validation_set)))
        print(validation_set[1][1].size())

        val_loader = DataLoader(
            dataset=validation_set,
            batch_size=self.BATCH_SIZE,
            shuffle=False,
            num_workers=self.num_workers)
        return val_loader
motion_dataloader.py 文件源码 项目:two-stream-action-recognition 作者: jeffreyhuang1 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def train(self):
        """Build the shuffling, pinned DataLoader over the motion training set.

        Reads ``self.dic_video_train``, ``self.in_channel``,
        ``self.data_path``, ``self.BATCH_SIZE`` and ``self.num_workers``.
        Prints the dataset length and the size of ``training_set[1][0]``.

        Returns:
            The training DataLoader.
        """
        training_set = motion_dataset(dic=self.dic_video_train, in_channel=self.in_channel, root_dir=self.data_path,
            mode='train',
            transform = transforms.Compose([
            transforms.Scale([224,224]),
            transforms.ToTensor(),
            ]))
        # Single-argument print(...) produces the same text on Python 2 and
        # Python 3. The double space before "videos" reproduces what the old
        # comma-separated print statement emitted.
        print('==> Training data : {}  videos {}'.format(
            len(training_set), training_set[1][0].size()))

        train_loader = DataLoader(
            dataset=training_set,
            batch_size=self.BATCH_SIZE,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
            )

        return train_loader
motion_dataloader.py 文件源码 项目:two-stream-action-recognition 作者: jeffreyhuang1 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def val(self):
        """Build the non-shuffling DataLoader over the motion validation set.

        Reads ``self.dic_test_idx``, ``self.in_channel``, ``self.data_path``,
        ``self.BATCH_SIZE`` and ``self.num_workers``. Prints the dataset
        length and the size of ``validation_set[1][1]``.

        Returns:
            The validation DataLoader.
        """
        validation_set = motion_dataset(dic= self.dic_test_idx, in_channel=self.in_channel, root_dir=self.data_path ,
            mode ='val',
            transform = transforms.Compose([
            transforms.Scale([224,224]),
            transforms.ToTensor(),
            ]))
        # Single-argument print(...) produces the same text on Python 2 and
        # Python 3. The double space before "frames" reproduces what the old
        # comma-separated print statement emitted.
        print('==> Validation data : {}  frames {}'.format(
            len(validation_set), validation_set[1][1].size()))

        val_loader = DataLoader(
            dataset=validation_set,
            batch_size=self.BATCH_SIZE,
            shuffle=False,
            num_workers=self.num_workers)

        return val_loader
Translator.py 文件源码 项目:NeuralMT 作者: hlt-mt 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def buildData(self, srcBatch, goldBatch):
        """Convert raw source/gold batches into an ``onmt.Dataset``.

        Args:
            srcBatch: Source items. For type "text" each item is converted
                with ``self.src_dict.convertToIdx``; for type "img" the first
                element of each item is a file name under
                ``self.opt.src_img_dir``.
            goldBatch: Optional target items; a falsy value yields no target
                data.

        Returns:
            An ``onmt.Dataset`` built with ``volatile=True``.

        Raises:
            ValueError: If ``self._type`` is neither "text" nor "img".
                Without this check the function failed later with an
                unbound ``srcData``.
        """
        # This needs to be the same as preprocess.py.
        if self._type == "text":
            srcData = [self.src_dict.convertToIdx(b,
                                                  onmt.Constants.UNK_WORD)
                       for b in srcBatch]
        elif self._type == "img":
            srcData = [transforms.ToTensor()(
                Image.open(self.opt.src_img_dir + "/" + b[0]))
                       for b in srcBatch]
        else:
            raise ValueError(
                "Unsupported data type: %r (expected 'text' or 'img')"
                % (self._type,))

        tgtData = None
        if goldBatch:
            tgtData = [self.tgt_dict.convertToIdx(b,
                       onmt.Constants.UNK_WORD,
                       onmt.Constants.BOS_WORD,
                       onmt.Constants.EOS_WORD) for b in goldBatch]

        return onmt.Dataset(srcData, tgtData, self.opt.batch_size,
                            self.opt.cuda, volatile=True,
                            data_type=self._type)
engine.py 文件源码 项目:wildcat.pytorch 作者: durandtibo 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def init_learning(self, model, criterion):
        """Fill in default train/val transforms and reset the best score.

        A transform is created only when ``self._state(<key>)`` is None, so
        the model's normalization attributes are read only in that case.
        ``criterion`` is not used here.
        """
        # (state key, whether a random horizontal flip is included)
        defaults = (('train_transform', True), ('val_transform', False))

        for key, with_flip in defaults:
            if self._state(key) is not None:
                continue
            normalize = transforms.Normalize(mean=model.image_normalization_mean,
                                             std=model.image_normalization_std)
            steps = [Warp(self.state['image_size'])]
            if with_flip:
                steps.append(transforms.RandomHorizontalFlip())
            steps.append(transforms.ToTensor())
            steps.append(normalize)
            self.state[key] = transforms.Compose(steps)

        self.state['best_score'] = 0
usps.py 文件源码 项目:pytorch-adda 作者: corenel 项目源码 文件源码 阅读 131 收藏 0 点赞 0 评论 0
def get_usps(train):
    """Get USPS dataset loader.

    Args:
        train: Forwarded to ``USPS`` as its ``train`` argument.

    Returns:
        A shuffling DataLoader with batch size ``params.batch_size``.
    """
    # Image pre-processing: tensor conversion followed by normalization with
    # the statistics configured in ``params``.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=params.dataset_mean,
                                     std=params.dataset_std)
    pre_process = transforms.Compose([to_tensor, normalize])

    # Dataset and data loader.
    dataset = USPS(root=params.data_root,
                   train=train,
                   transform=pre_process,
                   download=True)
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=params.batch_size,
                                       shuffle=True)
mnist_hogwild.py 文件源码 项目:ml-utils 作者: LinxiFan 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def train(rank, args, model):
    """Train and test ``model`` on MNIST for ``args.epochs`` epochs.

    The torch seed is set to ``args.seed + rank``. Each epoch calls
    ``train_epoch`` and then ``test_epoch``.
    """
    torch.manual_seed(args.seed + rank)

    data_root = '../data'

    def make_transform():
        # Tensor conversion followed by single-channel normalization.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    # Only the training split is created with download=True.
    train_set = datasets.MNIST(data_root, train=True, download=True,
                               transform=make_transform())
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=1)

    test_set = datasets.MNIST(data_root, train=False,
                              transform=make_transform())
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, shuffle=True, num_workers=1)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    epoch = 1
    while epoch <= args.epochs:
        train_epoch(epoch, args, model, train_loader, optimizer)
        test_epoch(model, test_loader)
        epoch += 1
datasets.py 文件源码 项目:mean-teacher 作者: CuriousAI 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def imagenet():
    """Return the transformations, data directory and class count for ImageNet.

    The training transformation is wrapped in ``data.TransformTwice``.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_steps = [
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]

    settings = {}
    settings['train_transformation'] = data.TransformTwice(
        transforms.Compose(train_steps))
    settings['eval_transformation'] = transforms.Compose(eval_steps)
    settings['datadir'] = 'data-local/images/ilsvrc2012/'
    settings['num_classes'] = 1000
    return settings
datasets.py 文件源码 项目:mean-teacher 作者: CuriousAI 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def cifar10():
    """Return the transformations, data directory and class count for CIFAR-10.

    The training transformation is wrapped in ``data.TransformTwice``.
    """
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2470,  0.2435,  0.2616]

    train_steps = [
        data.RandomTranslateWithReflect(4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    eval_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]

    settings = {}
    settings['train_transformation'] = data.TransformTwice(
        transforms.Compose(train_steps))
    settings['eval_transformation'] = transforms.Compose(eval_steps)
    settings['datadir'] = 'data-local/images/cifar/cifar10/by-image'
    settings['num_classes'] = 10
    return settings
coco_caption.py 文件源码 项目:seq2seq.pytorch 作者: eladhoffer 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def imagenet_transform(scale_size=256, input_size=224, train=True, allow_var_size=False):
    """Compose a preprocessing pipeline for the requested mode.

    Args:
        scale_size: Argument for the initial ``Scale`` step.
        input_size: Crop size for the random or center crop.
        train: If truthy, use a random crop plus random horizontal flip.
        allow_var_size: Only consulted when ``train`` is falsy; if truthy,
            no crop is applied, otherwise a center crop is used.
    """
    steps = [transforms.Scale(scale_size)]

    if train:
        steps.append(transforms.RandomCrop(input_size))
        steps.append(transforms.RandomHorizontalFlip())
    elif not allow_var_size:
        steps.append(transforms.CenterCrop(input_size))

    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225]))
    return transforms.Compose(steps)
pose_estimator.py 文件源码 项目:DeepPoseComparison 作者: ynaka81 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def __init__(self, Nj, gpu, model_file, filename):
        """Load the AlexNet weights and the pose dataset to estimate.

        Args:
            Nj: Argument passed to ``AlexNet``.
            gpu: GPU id; any value >= 0 enables CUDA.
            model_file: Path given to ``torch.load`` for the state dict.
            filename: First argument of ``PoseDataset``.

        Raises:
            GPUNotFoundError: If a GPU is requested but CUDA is unavailable.
        """
        # Validate arguments.
        self.gpu = (gpu >= 0)
        if self.gpu:
            if not torch.cuda.is_available():
                raise GPUNotFoundError('GPU is not found.')

        # Initialize the model to estimate with.
        self.model = AlexNet(Nj)
        state_dict = torch.load(model_file)
        self.model.load_state_dict(state_dict)
        if self.gpu:
            self.model.cuda()

        # Load the dataset to estimate on.
        input_pipeline = transforms.Compose([
            transforms.ToTensor(),
            RandomNoise()])
        self.dataset = PoseDataset(
            filename,
            input_transform=input_pipeline,
            output_transform=Scale(),
            transform=Crop(data_augmentation=True))
core_process.py 文件源码 项目:DeepPoseComparison 作者: ynaka81 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def __init__(self, Nj, gpu, model_file, filename):
        """Load the AlexNet weights and the pose dataset to estimate.

        Args:
            Nj: Argument passed to ``AlexNet``.
            gpu: GPU id; any value >= 0 enables CUDA.
            model_file: Path given to ``torch.load`` for the state dict.
            filename: First argument of ``PoseDataset``.

        Raises:
            GPUNotFoundError: If a GPU is requested but CUDA is unavailable.
        """
        # Validate arguments.
        self.gpu = (gpu >= 0)
        if self.gpu:
            if not torch.cuda.is_available():
                raise GPUNotFoundError('GPU is not found.')

        # Initialize the model to estimate with.
        self.model = AlexNet(Nj)
        state_dict = torch.load(model_file)
        self.model.load_state_dict(state_dict)
        if self.gpu:
            self.model.cuda()

        # Load the dataset to estimate on.
        input_pipeline = transforms.Compose([
            transforms.ToTensor(),
            RandomNoise()])
        self.dataset = PoseDataset(
            filename,
            input_transform=input_pipeline,
            output_transform=Scale(),
            transform=Crop())
base_dataset.py 文件源码 项目:CycleGANwithPerceptionLoss 作者: EliasVansteenkiste 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def get_transform(opt):
    """Assemble the image transform pipeline selected by ``opt.resize_or_crop``.

    Recognised modes are 'resize_and_crop', 'crop', 'scale_width' and
    'scale_width_and_crop'; any other value adds no resize/crop step.
    A random horizontal flip is added when ``opt.isTrain`` is set and
    ``opt.no_flip`` is not. ToTensor and a 0.5/0.5 Normalize always come last.
    """
    mode = opt.resize_or_crop
    steps = []

    if mode == 'resize_and_crop':
        steps.extend([
            transforms.Scale([opt.loadSize, opt.loadSize], Image.BICUBIC),
            transforms.RandomCrop(opt.fineSize),
        ])
    elif mode == 'crop':
        steps.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'scale_width':
        # opt.fineSize is read when the transform runs, not when it is built.
        steps.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif mode == 'scale_width_and_crop':
        steps.extend([
            transforms.Lambda(lambda img: __scale_width(img, opt.loadSize)),
            transforms.RandomCrop(opt.fineSize),
        ])

    if opt.isTrain and not opt.no_flip:
        steps.append(transforms.RandomHorizontalFlip())

    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return transforms.Compose(steps)
utils.py 文件源码 项目:autodiff 作者: bgavran 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def __init__(self, batch_size):
        """Create shuffling MNIST train/test loaders with the given batch size.

        Stores ``batch_size``, ``train_loader`` and ``test_loader`` on the
        instance. Only the training split is created with ``download=True``.
        """
        self.batch_size = batch_size

        data_root = "./data"
        train_dataset = dsets.MNIST(root=data_root,
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)
        test_dataset = dsets.MNIST(root=data_root,
                                   train=False,
                                   transform=transforms.ToTensor())

        # Both loaders share the same batch size and both shuffle.
        loader_options = dict(batch_size=batch_size, shuffle=True)
        self.train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                        **loader_options)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                       **loader_options)
unaligned_data_loader.py 文件源码 项目:pytorch_cycle_gan 作者: jinfagang 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def load_image_for_prediction(opt, image_path):
    """Load one image for prediction, pre-processed the same way as in training.

    The image is converted to RGB first so that grayscale, palette and RGBA
    files yield the three channels the 3-value Normalize below is written for.

    :param opt: options object providing ``loadSize`` and ``fineSize``
    :param image_path: path of the image file to open
    :return: dict whose 'A' and 'B' entries are the same tensor, with a
        leading dimension of size 1 added, and whose 'A_paths' and 'B_paths'
        entries are ``image_path``
    """
    image = Image.open(image_path).convert('RGB')
    transformations = transforms.Compose([transforms.Scale(opt.loadSize),
                                          transforms.RandomCrop(opt.fineSize),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.5, 0.5, 0.5),
                                                               (0.5, 0.5, 0.5))])
    image_tensor = transformations(image).float()
    # In-place: adds the leading dimension of size 1.
    image_tensor.unsqueeze_(0)
    return {'A': image_tensor, 'A_paths': image_path,
            'B': image_tensor, 'B_paths': image_path}
utils.py 文件源码 项目:sourceseparation_misc 作者: ycemsubakan 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def get_loaders(loader_batchsize, **kwargs):
    """Return ``(train_loader, test_loader)`` for the requested dataset.

    Args:
        loader_batchsize: Batch size for both loaders.
        **kwargs: Must contain ``arguments``, an object with a ``data``
            attribute (dataset name) and a ``cuda`` attribute.

    Returns:
        Tuple of shuffling MNIST loaders ``(train_loader, test_loader)``.

    Raises:
        KeyError: If ``arguments`` is not supplied.
        ValueError: If ``arguments.data`` is not 'mnist'. Without this check
            the function failed at ``return`` with unbound loader names.
    """
    arguments = kwargs['arguments']
    data = arguments.data

    if data != 'mnist':
        raise ValueError(
            "Unsupported dataset: %r (only 'mnist' is available)" % (data,))

    # Worker/pinning options are only set when CUDA is requested. A separate
    # name avoids shadowing the function's own **kwargs.
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if arguments.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                       ])),
        batch_size=loader_batchsize, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                       ])),
        batch_size=loader_batchsize, shuffle=True, **loader_kwargs)

    return train_loader, test_loader
data_utilities.py 文件源码 项目:generative_zoo 作者: DL-IT 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def LSUN_loader(root, image_size, classes=['bedroom'], normalize=True):
    """
        Function to load torchvision dataset object based on just image size
        Args:
            root        = If your dataset is downloaded and ready to use, mention the location of this folder. Else, the dataset will be downloaded to this location
            image_size  = Size of every image
            classes     = Default class is 'bedroom'. Other available classes are:
                        'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower'
                        A name that already ends in '_train', '_val' or '_test' is passed through as is.
                        A plain string is passed through unchanged.
            normalize   = Requirement to normalize the image. Default is true
        Returns:
            The ``dset.LSUN`` dataset object.
    """
    transformations = [transforms.Scale(image_size), transforms.CenterCrop(image_size), transforms.ToTensor()]
    if normalize:
        transformations.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

    if isinstance(classes, str):
        lsun_classes = classes
    else:
        # Build a new list: rebinding the loop variable, as the old loop did,
        # never changed `classes`, so the '_train' suffix was silently lost.
        # The shared default list is not mutated.
        split_suffixes = ('_train', '_val', '_test')
        lsun_classes = [c if c.endswith(split_suffixes) else c + '_train'
                        for c in classes]

    lsun_data   = dset.LSUN(db_path=root, classes=lsun_classes, transform=transforms.Compose(transformations))
    return lsun_data
base_dataset.py 文件源码 项目:pytorch-CycleGAN-and-pix2pix 作者: junyanz 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def get_transform(opt):
    """Assemble the image transform pipeline selected by ``opt.resize_or_crop``.

    Recognised modes are 'resize_and_crop', 'crop', 'scale_width' and
    'scale_width_and_crop'; any other value adds no resize/crop step.
    A random horizontal flip is added when ``opt.isTrain`` is set and
    ``opt.no_flip`` is not. ToTensor and a 0.5/0.5 Normalize always come last.
    """
    pipeline = []
    choice = opt.resize_or_crop

    if choice == 'resize_and_crop':
        target_size = [opt.loadSize, opt.loadSize]
        pipeline += [transforms.Scale(target_size, Image.BICUBIC),
                     transforms.RandomCrop(opt.fineSize)]
    elif choice == 'crop':
        pipeline += [transforms.RandomCrop(opt.fineSize)]
    elif choice == 'scale_width':
        # opt.fineSize is read when the transform runs, not when it is built.
        pipeline += [transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize))]
    elif choice == 'scale_width_and_crop':
        pipeline += [transforms.Lambda(
                         lambda img: __scale_width(img, opt.loadSize)),
                     transforms.RandomCrop(opt.fineSize)]

    if opt.isTrain and not opt.no_flip:
        pipeline += [transforms.RandomHorizontalFlip()]

    pipeline += [transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5),
                                      (0.5, 0.5, 0.5))]
    return transforms.Compose(pipeline)


问题


面经


文章

微信
公众号

扫码关注公众号