misc.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DomainTransferNetwork.pytorch 作者: taey16 项目源码 文件源码
def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', transform_fn=None,
              shuffle=True):
  """Build a DataLoader for SVHN or MNIST and return it together with its dataset.

  Args:
    datasetName: 'svhn' or 'mnist'.
    dataroot: root directory handed to the torchvision dataset; the data is
      downloaded there if missing (``download=True``).
    originalSize: size images are resized to before random cropping
      (default transform for 'train'/'extra' only).
    imageSize: random-crop size for 'train'/'extra'; resize size for 'test'.
    batchSize: number of samples per batch.
    workers: number of DataLoader worker processes (coerced with ``int``).
    mean, std: normalization constants passed to ``transforms.Normalize``.
    split: 'train', 'extra' or 'test'. For SVHN, 'train' is mapped to the
      'extra' split. For MNIST, any value other than 'train' selects the
      test set.
    transform_fn: optional transform. When None, a default is built for
      'train'/'extra'/'test'; any other split gets no transform.
    shuffle: whether the DataLoader shuffles samples. Defaults to True,
      matching the previous hard-coded behavior (including for 'test').

  Returns:
    A ``(dataloader, dataset)`` tuple.

  Raises:
    ValueError: if ``datasetName`` is neither 'svhn' nor 'mnist'.
  """
  if datasetName not in ('svhn', 'mnist'):
    # Previously an unknown name fell through both branches and crashed with
    # a NameError on the unbound ``dataset``; fail early and clearly instead.
    raise ValueError(
        "Unknown datasetName %r; expected 'svhn' or 'mnist'" % (datasetName,))

  import torch.utils.data
  import torchvision.transforms as transforms

  # ``transforms.Scale`` was deprecated in favour of ``Resize`` and later
  # removed from torchvision; prefer ``Resize``, fall back on old versions.
  resize = getattr(transforms, 'Resize', None) or transforms.Scale

  if transform_fn is None and split in ('train', 'extra'):
    # Training: resize, then take a random crop of imageSize (augmentation).
    transform_fn = transforms.Compose([resize(originalSize),
                                       transforms.RandomCrop(imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean, std),
                                      ])
  elif transform_fn is None and split == 'test':
    # Evaluation: deterministic resize straight to imageSize, no cropping.
    transform_fn = transforms.Compose([resize(imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean, std),
                                      ])

  if datasetName == 'svhn':
    from torchvision.datasets.svhn import SVHN as commonDataset
    # 'train' is deliberately mapped to SVHN's 'extra' split.
    if split == 'train':
      split = 'extra'
    dataset = commonDataset(root=dataroot,
                            download=True,
                            split=split,
                            transform=transform_fn)
  else:  # datasetName == 'mnist' (validated above)
    from torchvision.datasets.mnist import MNIST as commonDataset
    # NOTE(review): MNIST images are presumably single-channel while the
    # default mean/std have three entries -- confirm this matches the
    # torchvision version / caller expectations.
    dataset = commonDataset(root=dataroot,
                            download=True,
                            train=(split == 'train'),
                            transform=transform_fn)

  dataloader = torch.utils.data.DataLoader(dataset,
                                           batch_size=batchSize,
                                           shuffle=shuffle,
                                           num_workers=int(workers))
  return dataloader, dataset
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号