misc.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pix2pix.pytorch 作者: taey16 项目源码 文件源码
def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None):
  """Build a ``torch.utils.data.DataLoader`` for one of the supported datasets.

  Args:
    datasetName: Which dataset/transform pair to use. ``'trans'`` and
      ``'pix2pix'`` load the project's ``datasets.trans`` / ``datasets.pix2pix``
      classes together with ``transforms.pix2pix``; ``'folder'`` uses
      torchvision's ``ImageFolder`` with ``torchvision.transforms``.
    dataroot: Passed to the dataset constructor as ``root``.
    originalSize: Size handed to the resize transform applied first.
    imageSize: Size handed to the crop transform applied after resizing.
    batchSize: ``batch_size`` of the returned loader.
    workers: ``num_workers`` of the returned loader (converted with ``int``).
    mean: First argument of the ``Normalize`` transform.
    std: Second argument of the ``Normalize`` transform.
    split: ``'train'`` selects random crop + random horizontal flip; any
      other value selects a center crop with no flip.
    shuffle: ``shuffle`` flag of the returned loader.
    seed: Forwarded to the dataset constructor as ``seed`` for every dataset
      except ``'folder'`` (``ImageFolder`` is not given a seed).

  Returns:
    The configured ``torch.utils.data.DataLoader``.

  Raises:
    ValueError: If ``datasetName`` is not one of the supported names.
  """
  # Imports stay local so that only the backend actually requested has to be
  # importable.
  if datasetName == 'trans':
    from datasets.trans import trans as commonDataset
    import transforms.pix2pix as transforms
  elif datasetName == 'folder':
    from torchvision.datasets.folder import ImageFolder as commonDataset
    import torchvision.transforms as transforms
  elif datasetName == 'pix2pix':
    from datasets.pix2pix import pix2pix as commonDataset
    import transforms.pix2pix as transforms
  else:
    # Previously this fell through and crashed later with an unbound-local
    # error on `commonDataset`; fail early with an actionable message.
    raise ValueError(
        "unknown datasetName {!r}; expected 'trans', 'folder' or 'pix2pix'".format(datasetName))

  # Keep using `Scale` wherever the selected transforms module provides it;
  # fall back to `Resize` for torchvision releases that dropped `Scale`.
  resize = getattr(transforms, 'Scale', None)
  if resize is None:
    resize = transforms.Resize

  if split == 'train':
    cropping = [transforms.RandomCrop(imageSize),
                transforms.RandomHorizontalFlip()]
  else:
    cropping = [transforms.CenterCrop(imageSize)]

  pipeline = transforms.Compose(
      [resize(originalSize)]
      + cropping
      + [transforms.ToTensor(), transforms.Normalize(mean, std)])

  datasetKwargs = {'root': dataroot, 'transform': pipeline}
  if datasetName != 'folder':
    # Only the project datasets are constructed with a `seed` keyword.
    datasetKwargs['seed'] = seed
  dataset = commonDataset(**datasetKwargs)

  # Sanity check retained from the original: trips when the dataset object is
  # falsy (for a class defining __len__, that means it reports zero items).
  assert dataset
  dataloader = torch.utils.data.DataLoader(dataset,
                                           batch_size=batchSize,
                                           shuffle=shuffle,
                                           num_workers=int(workers))
  return dataloader
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号