calculate_inception_scores.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:gan-error-avoidance 作者: aleju 项目源码 文件源码
def __init__(self, opt):
        """Set up image access for the dataset selected by ``opt``.

        Builds the preprocessing pipeline (optional center crop, rescale to
        ``opt.image_size``, center crop to ``opt.image_size``, conversion to
        a tensor), opens the dataset named by ``opt.dataset`` and loads the
        index file ``data_index.pt`` located in ``opt.dataroot``.

        Args:
            opt: Options object. Attributes read here are ``crop_height``,
                ``crop_width``, ``crop_size``, ``image_size``, ``dataset``,
                ``dataroot`` and, when ``dataset == 'lsun'``, ``lsun_class``.

        Raises:
            ValueError: If ``opt.dataset`` is not one of ``'cifar10'``,
                ``'imagenet'``, ``'folder'``, ``'lfw'`` or ``'lsun'``.

        Sets ``self.opt``, ``self.get_data`` (callable mapping an integer
        index to the first element of the dataset item at that index),
        ``self.train_index`` (the ``'train'`` entry of the loaded index file)
        and ``self.counter`` (initialised to 0).
        """
        transform_list = []

        if (opt.crop_height > 0) and (opt.crop_width > 0):
            # CenterCrop accepts a single size argument; a rectangular crop
            # is given as an (h, w) pair.
            transform_list.append(
                transforms.CenterCrop((opt.crop_height, opt.crop_width)))
        elif opt.crop_size > 0:
            transform_list.append(transforms.CenterCrop(opt.crop_size))

        # NOTE(review): newer torchvision releases replace Scale with Resize;
        # kept as-is to stay compatible with the version this file targets.
        transform_list.append(transforms.Scale(opt.image_size))
        transform_list.append(transforms.CenterCrop(opt.image_size))

        transform_list.append(transforms.ToTensor())

        # One pipeline object is enough; it is shared by every dataset below.
        transform = transforms.Compose(transform_list)

        if opt.dataset == 'cifar10':
            # Two CIFAR10 datasets are exposed as one contiguous index range:
            # dataset1 (default split, downloaded if missing) followed by
            # dataset2 (train = False).
            dataset1 = datasets.CIFAR10(root = opt.dataroot, download = True,
                transform = transform)
            dataset2 = datasets.CIFAR10(root = opt.dataroot, train = False,
                transform = transform)

            def get_data(k):
                if k < len(dataset1):
                    return dataset1[k][0]
                # Indices past the end of dataset1 address dataset2.
                return dataset2[k - len(dataset1)][0]
        else:
            if opt.dataset in ['imagenet', 'folder', 'lfw']:
                dataset = datasets.ImageFolder(root = opt.dataroot,
                    transform = transform)
            elif opt.dataset == 'lsun':
                dataset = datasets.LSUN(db_path = opt.dataroot,
                    classes = [opt.lsun_class + '_train'],
                    transform = transform)
            else:
                # Fail here instead of with an unbound-variable error on the
                # first get_data() call.
                raise ValueError(
                    'Unsupported dataset: {!r}'.format(opt.dataset))

            def get_data(k):
                return dataset[k][0]

        data_index = torch.load(os.path.join(opt.dataroot, 'data_index.pt'))

        self.opt = opt
        self.get_data = get_data
        self.train_index = data_index['train']
        self.counter = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号