evaluation.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:chainer-gan-experiments 作者: Aixile 项目源码 文件源码
def gan_sampling_tags(gen, eval_folder, gpu, rows=6, cols=6, latent_len=128, attr_len=38, threshold=0.25):
    """Build a trainer extension that saves a grid of tag-conditioned samples.

    Each invocation of the returned extension draws ``rows * cols`` Gaussian
    latent vectors and random tag vectors. It runs ``gen`` on their
    concatenation with ``test=True`` and writes the resulting images to
    ``<eval_folder>/iter_<iteration>.jpg``.

    Args:
        gen: Generator link. ``gen.xp`` picks the array module for the tag
            batch, and ``gen(x, test=True)`` is called to produce images.
        eval_folder: Output directory; created on first use if missing.
        gpu: Device id; when ``>= 0`` the latent batch is moved to the GPU
            with ``cuda.to_gpu``.
        rows, cols: Grid size; ``rows * cols`` samples are drawn per call.
        latent_len: Length of each latent vector.
        attr_len: Length of each tag vector. Must be greater than 27, because
            indices ``27:`` form their own one-hot group.
        threshold: Uniform draws at or above this value switch on the
            independent tags at indices 13..26.

    Returns:
        ``samples_generation(trainer)``, marked with
        ``chainer.training.make_extension()`` for use as a trainer extension.

    Raises:
        ValueError: When the extension runs with ``attr_len <= 27``.
    """
    # Tag vector layout produced by get_fake_tag (entries are -1 or +1):
    #   [0, 13)          one-hot group: argmax of uniform draws is set to +1
    #   [13, 27)         independent tags: +1 where draw >= threshold
    #   [27, attr_len)   one-hot group: argmax of uniform draws is set to +1
    # What each group means is not visible here (presumably attribute
    # categories of the training labels -- confirm against the dataset).

    def get_fake_tag():
        """Sample one float32 tag vector following the layout above."""
        if attr_len <= 27:
            raise ValueError(
                "attr_len must be greater than 27 to hold the trailing "
                "one-hot tag group, got %d" % attr_len)
        draws = np.random.rand(attr_len)
        tags = np.full(attr_len, -1.0, dtype="f")
        tags[np.argmax(draws[:13])] = 1.0
        tags[27 + np.argmax(draws[27:])] = 1.0
        tags[13:27] = np.where(draws[13:27] >= threshold, 1.0, -1.0)
        return tags

    def get_fake_tag_batch():
        """Return a ``(rows * cols, attr_len)`` float32 tag batch on gen's device."""
        batch = rows * cols
        # Fill on the host and transfer once. Writing row-by-row into a device
        # array would cost one host->device copy per row.
        host_tags = np.empty((batch, attr_len), dtype="f")
        for i in range(batch):
            host_tags[i] = get_fake_tag()
        return gen.xp.asarray(host_tags)

    @chainer.training.make_extension()
    def samples_generation(trainer):
        """Save one sample grid for the trainer's current iteration."""
        if not os.path.exists(eval_folder):
            os.makedirs(eval_folder)
        # z is drawn before the tags, so the global NumPy RNG is consumed in
        # a fixed order (latents first, then tags).
        z = np.random.normal(size=(rows * cols, latent_len)).astype("f")
        # NOTE(review): z is placed using `gpu` while the tags use gen.xp.
        # Both must agree or F.concat will see mixed array types.
        if gpu >= 0:
            z = cuda.to_gpu(z)
        tags = get_fake_tag_batch()
        # Chainer v1-style volatile Variables (presumably so no backward
        # graph is built while sampling).
        z = Variable(z, volatile=True)
        tags = Variable(tags, volatile=True)
        imgs = gen(F.concat([z, tags]), test=True)
        out_path = os.path.join(
            eval_folder, "iter_" + str(trainer.updater.iteration) + ".jpg")
        # NOTE(review): grid_w receives `rows` and grid_h receives `cols`.
        # This only matters when rows != cols; confirm which one
        # save_images_grid expects.
        save_images_grid(imgs, path=out_path, grid_w=rows, grid_h=cols)

    return samples_generation
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号