def _sample_fake_tags(batch, attr_len, threshold):
    """Draw ``batch`` random attribute vectors for tag-conditioned sampling.

    Each row is a float32 vector of length ``attr_len`` whose entries are
    all -1.0 or 1.0:

    * columns 0-12: exactly one entry is 1.0 (argmax of uniform noise);
    * columns 13-26: independent flags, 1.0 where the uniform draw is
      ``>= threshold`` and -1.0 otherwise;
    * columns 27 onward: exactly one entry is 1.0 (argmax of uniform noise).

    NOTE(review): the 13 / 27 boundaries presumably separate mutually
    exclusive tag groups from independent tags; confirm against the tag
    list the generator was trained on.

    Randomness comes from the global ``np.random`` state. The whole batch is
    drawn with one ``rand(batch, attr_len)`` call, which consumes the stream
    in the same order as ``batch`` consecutive ``rand(attr_len)`` calls.

    Args:
        batch: number of rows to generate.
        attr_len: length of each attribute vector; must be at least 28.
        threshold: cut-off for the independent flags in columns 13-26.

    Returns:
        numpy.ndarray of shape ``(batch, attr_len)`` with dtype float32.

    Raises:
        ValueError: if ``attr_len`` leaves no room for the last tag group.
    """
    first_group_end = 13   # columns [0, 13) are one-hot
    last_group_start = 27  # columns [27, attr_len) are one-hot
    if attr_len <= last_group_start:
        raise ValueError(
            "attr_len must be at least %d, got %d"
            % (last_group_start + 1, attr_len))

    prob = np.random.rand(batch, attr_len)
    tags = np.full((batch, attr_len), -1.0, dtype=np.float32)
    row_idx = np.arange(batch)

    # One-hot groups: mark the column with the largest noise value per row.
    tags[row_idx, np.argmax(prob[:, :first_group_end], axis=1)] = 1.0
    tags[row_idx,
         last_group_start + np.argmax(prob[:, last_group_start:], axis=1)] = 1.0

    # Independent flags between the two groups.
    middle = prob[:, first_group_end:last_group_start]
    tags[:, first_group_end:last_group_start] = np.where(
        middle >= threshold, 1.0, -1.0)
    return tags


def gan_sampling_tags(gen, eval_folder, gpu, rows=6, cols=6, latent_len=128,
                      attr_len=38, threshold=0.25):
    """Build a trainer extension that saves a grid of tag-conditioned samples.

    Each time the returned extension runs, it:

    1. ensures ``eval_folder`` exists;
    2. draws ``rows * cols`` latent vectors of length ``latent_len`` from a
       standard normal, plus random tag vectors (see ``_sample_fake_tags``);
    3. feeds ``concat([z, tags])`` to ``gen``;
    4. writes the images to ``<eval_folder>/iter_<iteration>.jpg``.

    Args:
        gen: generator link. Its ``xp`` attribute decides where the inputs
            are placed, so they always share the generator's device.
        eval_folder: output directory; created if missing.
        gpu: kept for signature compatibility. The device is taken from
            ``gen.xp``, which avoids mixing CPU and GPU arrays in the concat.
        rows, cols: grid dimensions; ``rows * cols`` samples are generated.
        latent_len: dimensionality of the noise vector ``z``.
        attr_len: length of each tag vector (must be at least 28).
        threshold: cut-off for the independent tags.

    Returns:
        The ``samples_generation(trainer)`` extension function.
    """
    batch = rows * cols

    # The decorator belongs on the function handed to the trainer, not on a
    # helper, so the extension itself carries Chainer's extension attributes.
    @chainer.training.make_extension()
    def samples_generation(trainer):
        # Create the folder, tolerating a concurrent creator.
        try:
            os.makedirs(eval_folder)
        except OSError:
            if not os.path.isdir(eval_folder):
                raise

        xp = gen.xp
        # z is drawn before the tags to keep the original RNG order.
        z = np.random.normal(size=(batch, latent_len)).astype("f")
        tags = _sample_fake_tags(batch, attr_len, threshold)

        # NOTE(review): `volatile=` and `test=` are Chainer v1 APIs (removed
        # in v2); update both if this code is ported to a newer Chainer.
        z = Variable(xp.asarray(z), volatile=True)
        tags = Variable(xp.asarray(tags), volatile=True)
        imgs = gen(F.concat([z, tags]), test=True)

        out_path = os.path.join(
            eval_folder, "iter_" + str(trainer.updater.iteration) + ".jpg")
        # NOTE(review): rows is passed as grid_w and cols as grid_h. This is
        # identical for the default square grid; confirm against
        # save_images_grid before using non-square grids.
        save_images_grid(imgs, path=out_path, grid_w=rows, grid_h=cols)

    return samples_generation
评论列表
文章目录