unet.py source code

python

Project: neural-fonts    Author: periannath
def infer(self, source_obj, embedding_ids, model_dir, save_dir, progress_file):
        source_provider = InjectDataProvider(source_obj, None)

        with open(progress_file, 'a') as f:
            f.write("Start")

        if isinstance(embedding_ids, int) or len(embedding_ids) == 1:
            embedding_id = embedding_ids if isinstance(embedding_ids, int) else embedding_ids[0]
            source_iter = source_provider.get_single_embedding_iter(self.batch_size, embedding_id)
        else:
            source_iter = source_provider.get_random_embedding_iter(self.batch_size, embedding_ids)

        tf.global_variables_initializer().run()
        saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
        self.restore_model(saver, model_dir)

        def save_imgs(imgs, count):
            p = os.path.join(save_dir, "inferred_%04d.png" % count)
            save_concat_images(imgs, img_path=p)
#            print("generated images saved at %s" % p)

        def save_sample(imgs, code):
            p = os.path.join(save_dir, "inferred_%s.png" % code)
            save_concat_images(imgs, img_path=p)
#            print("generated images saved at %s" % p)

        count = 0
        batch_buffer = list()
        for labels, codes, source_imgs in source_iter:
            fake_imgs = self.generate_fake_samples(source_imgs, labels)[0]
            for i in range(len(fake_imgs)):
                # Denormalize from [-1, 1] back to 8-bit grayscale in [0, 255]
                gray_img = np.uint8(fake_imgs[i][:,:,0]*127.5+127.5)
                pil_img = Image.fromarray(gray_img, 'L')
                # Apply OpenCV's bilateralFilter (d=5, sigmaColor=10, sigmaSpace=10) to denoise while preserving stroke edges
                cv_img = np.array(pil_img)
                cv_img = bilateralFilter(cv_img, 5, 10, 10)
                pil_img = Image.fromarray(cv_img)
                # Increase contrast
                enhancer = ImageEnhance.Contrast(pil_img)
                en_img = enhancer.enhance(1.5)
                # Re-normalize the enhanced image back to [-1, 1]
                fake_imgs[i][:, :, 0] = np.array(en_img) / 127.5 - 1.
#                save_sample(fake_imgs[i], codes[i])
            merged_fake_images = merge(scale_back(fake_imgs), [self.batch_size, 1])
            batch_buffer.append(merged_fake_images)
            # Save each merged batch as soon as it is produced, named after the batch's first code
            if len(batch_buffer) == 1:
                save_sample(batch_buffer, codes[0])
                batch_buffer = list()
            count += 1
        if batch_buffer:
            # last batch
            save_imgs(batch_buffer, count)
        with open(progress_file, 'a') as f:
            f.write("Done")
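
The infer method above restores the trained generator variables, runs the source glyphs through the network batch by batch, post-processes each generated glyph, and writes concatenated result images to save_dir. For reference, the per-glyph post-processing (denormalize to 8-bit, bilateral filtering, contrast boost, re-normalize to [-1, 1]) can be exercised on its own. Below is a minimal standalone sketch, assuming numpy, OpenCV (cv2) and Pillow are installed; the helper name postprocess_glyph and the random test input are illustrative and not part of the project:

import numpy as np
import cv2
from PIL import Image, ImageEnhance

def postprocess_glyph(channel, d=5, sigma_color=10, sigma_space=10, contrast=1.5):
    """Post-process a single generated glyph channel in [-1, 1], mirroring infer()."""
    # Denormalize [-1, 1] -> 8-bit grayscale [0, 255]
    gray = np.uint8(channel * 127.5 + 127.5)
    # Edge-preserving denoising with OpenCV's bilateral filter
    smoothed = cv2.bilateralFilter(gray, d, sigma_color, sigma_space)
    # Contrast boost via Pillow
    enhanced = ImageEnhance.Contrast(Image.fromarray(smoothed, 'L')).enhance(contrast)
    # Re-normalize back to [-1, 1]
    return np.array(enhanced, dtype=np.float32) / 127.5 - 1.

# Illustrative usage on a random stand-in for a generated glyph
fake_channel = np.random.uniform(-1., 1., size=(64, 64)).astype(np.float32)
out = postprocess_glyph(fake_channel)
print(out.shape, out.min(), out.max())

The 127.5 scale and offset keep the round trip symmetric with the [-1, 1] range that fake_imgs is held in before scale_back is applied.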