def LSUN_loader(root, image_size, classes=None, normalize=True):
    """Load a torchvision LSUN dataset with resize / center-crop preprocessing.

    Args:
        root: Location of the LSUN database folder; forwarded to ``dset.LSUN``
            as ``db_path``. This function does not request any download.
        image_size: Target size; images pass through ``transforms.Scale`` and
            then ``transforms.CenterCrop`` with this value, then ``ToTensor``.
        classes: List of LSUN scene categories, default ``['bedroom']``.
            ``'_train'`` is appended to every name that does not already end
            in ``'_train'``, ``'_val'`` or ``'_test'``. Other categories:
            'bridge', 'church_outdoor', 'classroom', 'conference_room',
            'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower'.
            A plain string (e.g. ``'train'``) is forwarded unchanged.
        normalize: If truthy (default ``True``), append
            ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` after ``ToTensor``.

    Returns:
        The ``dset.LSUN`` dataset object.
    """
    if classes is None:
        # Fresh list per call instead of a shared mutable default.
        classes = ['bedroom']

    transformations = [
        transforms.Scale(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ]
    if normalize:
        transformations.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

    if isinstance(classes, str):
        # A bare string was always passed straight through; keep doing so.
        lsun_classes = classes
    else:
        # Rebinding the loop variable never modified `classes`, so the split
        # suffix was silently dropped. Build a new suffixed list instead, and
        # leave names that already carry a split suffix untouched.
        split_suffixes = ('_train', '_val', '_test')
        lsun_classes = [
            c if c.endswith(split_suffixes) else c + '_train' for c in classes
        ]

    return dset.LSUN(db_path=root, classes=lsun_classes,
                     transform=transforms.Compose(transformations))
# Python class LSUN: example source code
def __init__(self, opt):
    """Set up per-index image access for the dataset chosen by ``opt.dataset``.

    Builds a transform pipeline and stores a ``get_data(k)`` accessor that
    returns only the image tensor (element ``[0]``) of the k-th sample.
    Pipeline: an optional initial CenterCrop, then Scale, CenterCrop and
    ToTensor with ``opt.image_size``. The initial crop is
    ``(opt.crop_height, opt.crop_width)`` when both are positive, otherwise
    ``opt.crop_size`` when positive.

    Supported ``opt.dataset`` values:
      * 'cifar10'  - the train split (downloaded if needed) followed by the
        test split, addressed as one contiguous index range.
      * 'imagenet', 'folder', 'lfw' - ``datasets.ImageFolder`` at ``opt.dataroot``.
      * 'lsun'     - the ``opt.lsun_class + '_train'`` category.

    Also loads ``data_index.pt`` from ``opt.dataroot`` and keeps its
    ``'train'`` entry as ``self.train_index``.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    transform_list = []
    if opt.crop_height > 0 and opt.crop_width > 0:
        # CenterCrop takes one size argument; a (height, width) pair must be a tuple.
        transform_list.append(transforms.CenterCrop((opt.crop_height, opt.crop_width)))
    elif opt.crop_size > 0:
        transform_list.append(transforms.CenterCrop(opt.crop_size))
    transform_list.append(transforms.Scale(opt.image_size))
    transform_list.append(transforms.CenterCrop(opt.image_size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)

    if opt.dataset == 'cifar10':
        dataset1 = datasets.CIFAR10(root=opt.dataroot, download=True,
                                    transform=transform)
        dataset2 = datasets.CIFAR10(root=opt.dataroot, train=False,
                                    transform=transform)

        def get_data(k):
            # Indices past the train split continue into the test split.
            if k < len(dataset1):
                return dataset1[k][0]
            return dataset2[k - len(dataset1)][0]
    else:
        if opt.dataset in ['imagenet', 'folder', 'lfw']:
            dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
        elif opt.dataset == 'lsun':
            dataset = datasets.LSUN(db_path=opt.dataroot,
                                    classes=[opt.lsun_class + '_train'],
                                    transform=transform)
        else:
            # Fail here rather than with an unbound-name error in get_data later.
            raise ValueError('Unknown dataset %s' % opt.dataset)

        def get_data(k):
            return dataset[k][0]

    # NOTE(review): torch.load unpickles the file; only load trusted data_index.pt files.
    data_index = torch.load(os.path.join(opt.dataroot, 'data_index.pt'))
    self.opt = opt
    self.get_data = get_data
    self.train_index = data_index['train']
    self.counter = 0
def get_dataloader(opt):
    """Build a shuffling DataLoader over the dataset chosen by ``opt.dataset``.

    Supported values:
      * 'imagenet', 'folder', 'lfw' - ImageFolder at ``opt.dataroot``; Scale to
        ``opt.imageScaleSize`` then CenterCrop to ``opt.imageSize``.
      * 'lsun' - the 'bedroom_train' category, same Scale/CenterCrop.
      * 'cifar10' - downloaded to ``opt.dataroot`` if needed; Scale to
        ``opt.imageSize`` only.
    Every pipeline ends with ToTensor and Normalize(mean=0.5, std=0.5) on
    three channels.

    The loader uses ``opt.batch_size``, ``shuffle=True`` and
    ``int(opt.workers)`` worker processes.

    Raises:
        ValueError: if ``opt.dataset`` is not a supported name.
        RuntimeError: if the constructed dataset is empty.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Scale(opt.imageScaleSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       normalize,
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Scale(opt.imageScaleSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                normalize,
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Scale(opt.imageSize),
                                   transforms.ToTensor(),
                                   normalize,
                               ]))
    else:
        # Without this branch `dataset` would be unbound below.
        raise ValueError('Unknown dataset %s' % opt.dataset)

    # Explicit check instead of `assert`, which is removed under `python -O`.
    if len(dataset) == 0:
        raise RuntimeError('Dataset %s at %s is empty' % (opt.dataset, opt.dataroot))

    return torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                       shuffle=True,
                                       num_workers=int(opt.workers))
def get_data(args, train_flag=True):
    """Return the dataset selected by ``args.dataset``.

    Args:
        args: Options object; reads ``dataset``, ``dataroot`` and ``image_size``.
        train_flag: Selects the training split for 'cifar10', 'cifar100' and
            'mnist', and the 'train' (vs 'val') subfolder for 'celeba'. It is
            not used for 'imagenet'/'folder'/'lfw' or for 'lsun', which always
            loads 'bedroom_train'.

    Returns:
        The constructed dataset object. All datasets except 'celeba' use
        Scale + CenterCrop to ``args.image_size``, then ToTensor and
        Normalize(mean=0.5, std=0.5). 'celeba' uses ALICropAndScale instead of
        Scale/CenterCrop.

    Raises:
        ValueError: if 'celeba' is requested with ``args.image_size != 64``, or
            if ``args.dataset`` is not a recognized name.
    """
    # MNIST samples are single-channel grayscale; a 3-element mean/std does not
    # broadcast onto a 1-channel tensor in vectorized Normalize implementations.
    # With one channel the result matches what older zip-based Normalize did.
    channels = 1 if args.dataset == 'mnist' else 3
    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * channels, (0.5,) * channels),
    ])

    if args.dataset in ['imagenet', 'folder', 'lfw']:
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transform)
    elif args.dataset == 'lsun':
        dataset = dset.LSUN(db_path=args.dataroot,
                            classes=['bedroom_train'],
                            transform=transform)
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot,
                               download=True,
                               train=train_flag,
                               transform=transform)
    elif args.dataset == 'cifar100':
        dataset = dset.CIFAR100(root=args.dataroot,
                                download=True,
                                train=train_flag,
                                transform=transform)
    elif args.dataset == 'mnist':
        dataset = dset.MNIST(root=args.dataroot,
                             download=True,
                             train=train_flag,
                             transform=transform)
    elif args.dataset == 'celeba':
        imdir = 'train' if train_flag else 'val'
        dataroot = os.path.join(args.dataroot, imdir)
        if args.image_size != 64:
            raise ValueError('the image size for CelebA dataset need to be 64!')
        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       ALICropAndScale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )
    else:
        raise ValueError("Unknown dataset %s" % (args.dataset))
    return dataset