calculate_inception_scores.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:gan-error-avoidance 作者: aleju 项目源码 文件源码
def _generate_images(self, nb_batches, g_fp, r_idx, opt, show_info, queue):
        """Sample images from a generator checkpoint and hand them back via a queue.

        torch is imported locally and the result is returned through ``queue``
        rather than a return value. Presumably this runs inside a separate
        worker process; confirm against the caller.

        Args:
            nb_batches: number of generator forward passes; each samples
                ``opt.batch_size`` noise codes of size ``opt.code_size``.
            g_fp: file path of the generator ``state_dict`` to load.
            r_idx: forwarded to the generator as ``n_execute_lis_layers``
                (presumably how many learned-input-space layers to run).
            opt: options object. Reads width, height, nfeature, nlayer,
                code_size, norm, r_iterations, g_upscaling and batch_size.
            show_info: if truthy, print the generator module.
            queue: object with a ``put`` method. It receives exactly one
                ``pickle.dumps`` payload of ``{"g_fp": g_fp, "images": [...]}``.
                ``images`` holds one uint8 array per generated sample, with the
                channel axis moved last (assumes NCHW output, so HWC arrays).
        """
        import torch
        from torch.autograd import Variable

        gen = GeneratorLearnedInputSpace(
            opt.width, opt.height, opt.nfeature, opt.nlayer, opt.code_size,
            opt.norm, n_lis_layers=opt.r_iterations, upscaling=opt.g_upscaling)
        if show_info:
            print("G:", gen)
        gen.cuda()
        gen.load_state_dict(torch.load(g_fp))
        # NOTE(review): the model is left in train mode rather than eval().
        # This matters if it contains batch-norm/dropout layers; confirm it is
        # intentional.
        gen.train()

        print("Generating images for checkpoint G'%s'..." % (g_fp,))
        images_all = []
        for _batch_idx in range(nb_batches):
            # NOTE: `volatile=True` disables autograd only on PyTorch < 0.4.
            # Newer versions ignore it with a warning.
            # TODO: switch to `torch.no_grad()` once 0.3 support is dropped.
            code = Variable(torch.randn(opt.batch_size, opt.code_size).cuda(), volatile=True)
            images, _ = gen(code, n_execute_lis_layers=r_idx)

            # Assumes generator output lies in [0, 1]. Clipping keeps the uint8
            # cast well-defined (no wrap-around) if values stray outside that
            # range. In-range values are unchanged (still truncated, not rounded).
            pixels = np.clip(images.data.cpu().numpy() * 255, 0, 255)
            images_np = pixels.astype(np.uint8).transpose((0, 2, 3, 1))
            images_all.extend(images_np)

        result_str = pickle.dumps({
            "g_fp": g_fp,
            "images": images_all
        }, protocol=-1)
        queue.put(result_str)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号