def infer(self, source_obj, embedding_ids, model_dir, save_dir, progress_file):
    """Run the restored generator over the source data and save the outputs.

    "Start" is appended to ``progress_file`` before inference and "Done"
    after every batch has been written (both without a trailing newline).
    Each generated image is post-processed (bilateral filter + contrast
    boost) before the batch is merged and saved.

    Args:
        source_obj: Data handed to ``InjectDataProvider(source_obj, None)``
            (presumably the source glyph images -- confirm with caller).
        embedding_ids: A single embedding id (int) or a sequence of ids.
            An int or a one-element sequence selects the provider's
            single-embedding iterator; a longer sequence selects its
            random-embedding iterator.
        model_dir: Checkpoint location passed to ``self.restore_model``.
        save_dir: Directory that receives one ``inferred_<code>.png`` per
            batch.
        progress_file: Path of a text file used as a coarse progress marker.
    """
    source_provider = InjectDataProvider(source_obj, None)
    with open(progress_file, 'a') as f:
        f.write("Start")

    if isinstance(embedding_ids, int) or len(embedding_ids) == 1:
        embedding_id = embedding_ids if isinstance(embedding_ids, int) else embedding_ids[0]
        source_iter = source_provider.get_single_embedding_iter(self.batch_size, embedding_id)
    else:
        source_iter = source_provider.get_random_embedding_iter(self.batch_size, embedding_ids)

    tf.global_variables_initializer().run()
    # The saver only covers generator variables, so any other variables keep
    # the values assigned by the initializer above.
    saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
    self.restore_model(saver, model_dir)

    def postprocess(img):
        """Denoise and boost the contrast of channel 0 of ``img`` in place."""
        # Denormalize [-1, 1] -> [0, 255]. Clip first so out-of-range values
        # saturate instead of wrapping around in the uint8 cast.
        # (Assumes generator output is in [-1, 1] -- TODO confirm.)
        gray = np.clip(img[:, :, 0] * 127.5 + 127.5, 0, 255).astype(np.uint8)
        # Bilateral filter: diameter 5, sigmaColor 10, sigmaSpace 10.
        gray = bilateralFilter(gray, 5, 10, 10)
        enhanced = ImageEnhance.Contrast(Image.fromarray(gray)).enhance(1.5)
        # Renormalize to [-1, 1]. Write the ndarray itself back, rather than
        # a PIL Image that numpy would have to coerce implicitly.
        img[:, :, 0] = np.asarray(enhanced) / 127.5 - 1.

    for labels, codes, source_imgs in source_iter:
        fake_imgs = self.generate_fake_samples(source_imgs, labels)[0]
        # Iterating the array yields views, so postprocess edits fake_imgs.
        for img in fake_imgs:
            postprocess(img)
        merged_fake_images = merge(scale_back(fake_imgs), [self.batch_size, 1])
        # One file per batch, named after codes[0]. A later batch with the
        # same first code overwrites the earlier file.
        p = os.path.join(save_dir, "inferred_%s.png" % codes[0])
        save_concat_images([merged_fake_images], img_path=p)

    with open(progress_file, 'a') as f:
        f.write("Done")
评论列表
文章目录