def load_data(self):
    """Build the synthetic and real image DataLoaders.

    Both datasets are read with ``torchvision.datasets.ImageFolder`` (from
    ``cfg.syn_path`` and ``cfg.real_path``) and share one preprocessing
    pipeline:
    - convert each image to single-channel grayscale;
    - resize it to ``cfg.img_height`` x ``cfg.img_width``;
    - convert it to a tensor;
    - rescale pixel values from [0, 1] to [-1, 1].

    Sets:
        self.syn_train_loader: shuffled DataLoader over the synthetic images.
        self.real_loader: shuffled DataLoader over the real images.
    """
    print('=' * 50)
    print('Loading data...')
    transform = transforms.Compose([
        # Public replacement for the old ``transforms.ImageOps.grayscale``,
        # which only worked because PIL's ImageOps leaked through the
        # transforms module namespace in early torchvision releases.
        transforms.Grayscale(num_output_channels=1),
        # Resize expects (height, width); the deprecated Scale call passed
        # (width, height), which swaps the axes under Resize semantics.
        transforms.Resize((cfg.img_height, cfg.img_width)),
        transforms.ToTensor(),
        # Images are single-channel after Grayscale, so mean/std need one
        # entry each; a 3-entry Normalize fails to broadcast in place.
        transforms.Normalize((0.5,), (0.5,)),
    ])

    def _make_loader(root):
        """Return a shuffled, pinned-memory DataLoader over ImageFolder(root)."""
        folder = torchvision.datasets.ImageFolder(root=root, transform=transform)
        return Data.DataLoader(folder, batch_size=cfg.batch_size, shuffle=True,
                               pin_memory=True)

    self.syn_train_loader = _make_loader(cfg.syn_path)
    print('syn_train_batch %d' % len(self.syn_train_loader))
    self.real_loader = _make_loader(cfg.real_path)
    print('real_batch %d' % len(self.real_loader))
评论列表
文章目录