def __init__(self, opt):
    """Build the image transform pipeline and an index-based data accessor.

    Attributes read from ``opt``:
        crop_height, crop_width: when both are > 0, the image is first
            center-cropped to a (crop_height, crop_width) rectangle.
        crop_size: otherwise, when > 0, the image is first center-cropped
            to a square of this size.
        image_size: passed to the resize transform and to the final
            center crop.
        dataset: one of 'cifar10', 'imagenet', 'folder', 'lfw', 'lsun'.
        dataroot: dataset root; 'data_index.pt' is loaded from here too.
        lsun_class: LSUN category name (read only when dataset is 'lsun').

    Attributes set on ``self``:
        opt: the options object, stored as given.
        get_data: callable ``get_data(k)`` returning element 0 of item
            ``k`` of the underlying dataset (for 'cifar10', indices past
            the end of the train split continue into the test split).
        train_index: the 'train' entry of the loaded 'data_index.pt'.
        counter: initialised to 0.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    transform_list = []
    if opt.crop_height > 0 and opt.crop_width > 0:
        # CenterCrop takes a single size argument; a rectangular crop is
        # given as an (h, w) pair.
        transform_list.append(
            transforms.CenterCrop((opt.crop_height, opt.crop_width)))
    elif opt.crop_size > 0:
        transform_list.append(transforms.CenterCrop(opt.crop_size))

    # torchvision renamed `Scale` to `Resize` and later dropped the old
    # name; prefer the new one but keep working on old installations.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    transform_list.append(resize(opt.image_size))
    transform_list.append(transforms.CenterCrop(opt.image_size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)

    if opt.dataset == 'cifar10':
        # download=True on the first dataset fetches the archive; the second
        # dataset is constructed afterwards without requesting a download.
        dataset1 = datasets.CIFAR10(root=opt.dataroot, download=True,
                                    transform=transform)
        dataset2 = datasets.CIFAR10(root=opt.dataroot, train=False,
                                    transform=transform)
        num_first = len(dataset1)

        def get_data(k):
            # Indices past the first split address the second split.
            if k < num_first:
                return dataset1[k][0]
            return dataset2[k - num_first][0]
    else:
        if opt.dataset in ('imagenet', 'folder', 'lfw'):
            dataset = datasets.ImageFolder(root=opt.dataroot,
                                           transform=transform)
        elif opt.dataset == 'lsun':
            # The path is passed positionally because its keyword name
            # differs between torchvision versions (`db_path` vs `root`).
            dataset = datasets.LSUN(opt.dataroot,
                                    classes=[opt.lsun_class + '_train'],
                                    transform=transform)
        else:
            # Fail here instead of with a NameError on the first get_data call.
            raise ValueError('unsupported dataset: %r' % (opt.dataset,))

        def get_data(k):
            return dataset[k][0]

    # NOTE: torch.load unpickles the file; only point dataroot at trusted data.
    data_index = torch.load(os.path.join(opt.dataroot, 'data_index.pt'))

    self.opt = opt
    self.get_data = get_data
    self.train_index = data_index['train']
    self.counter = 0
calculate_inception_scores.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录