evaluation.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:chainer-cyclegan 作者: Aixile 项目源码 文件源码
def evaluation(gen_g, gen_f, test_image_folder, image_size=256, side=2):
    """Build a trainer extension that saves sample outputs of both generators.

    Each call of the returned function pulls one batch from the trainer's
    ``'test'`` iterator. It runs ``gen_g`` on the first element of every
    example and ``gen_f`` on the second element. Each batch of outputs is
    tiled into a ``side x side`` grid and written to
    ``<test_image_folder>/iter_<iteration>_G.jpg`` and ``..._F.jpg``.

    Args:
        gen_g: generator applied to the first element of each test example
            (must expose ``.xp`` and accept ``test=True``).
        gen_f: generator applied to the second element of each test example.
        test_image_folder: directory the JPEG grids are written into.
        image_size: height/width of each example; examples are copied into
            ``(3, image_size, image_size)`` float32 buffers.
        side: number of tiles per grid row/column. The tiling reshape
            requires the test batch size to equal ``side * side``.

    Returns:
        A callable ``evaluation(trainer)`` to pass to ``trainer.extend``.
    """
    def _to_tiled_image(img):
        """Turn a ``(side*side, 3, H, W)`` output array into a uint8 HxWx3 grid."""
        # Host copy regardless of whether the generator ran on CPU or GPU
        # (``.get()`` only exists on CuPy arrays).
        img = chainer.cuda.to_cpu(img)
        # (side*side, 3, H, W) -> (side, side, H, W, 3): channels last for PIL.
        tiles = img.reshape((side, side, 3, image_size, image_size))
        tiles = tiles.transpose(0, 1, 3, 4, 2)
        # Scale by (x + 1) * 127.5 into [0, 255]; presumably the generator
        # outputs lie in [-1, 1] (e.g. tanh) — TODO confirm.
        tiles = np.clip((tiles + 1) * 127.5, 0, 255).astype(np.uint8)
        # (side, side, H, W, 3) -> (side, H, side, W, 3) so the reshape lays
        # the tiles out row by row as one (side*H, side*W, 3) grid.
        grid = tiles.transpose(0, 2, 1, 3, 4).reshape(
            (side * image_size, side * image_size, 3))
        # Reverse the channel order before saving; presumably the data is
        # stored BGR and PIL expects RGB — NOTE(review): confirm.
        return grid[:, :, ::-1]

    # NOTE(review): this decorator only annotates ``_eval``; the callable
    # actually returned (and registered with the trainer) is ``evaluation``
    # below, so these extension attributes are never seen — confirm intent.
    @chainer.training.make_extension()
    def _eval(trainer, it):
        xp = gen_g.xp
        batch = next(it)
        batchsize = len(batch)

        # Stack the two halves of each example into device-side float32 batches.
        x = xp.zeros((batchsize, 3, image_size, image_size), dtype="f")
        t = xp.zeros((batchsize, 3, image_size, image_size), dtype="f")
        for i, example in enumerate(batch):
            x[i, :] = xp.asarray(example[0])
            t[i, :] = xp.asarray(example[1])

        prefix = test_image_folder + "/iter_" + str(trainer.updater.iteration)

        result = gen_g(Variable(x), test=True)
        Image.fromarray(_to_tiled_image(result.data)).save(prefix + "_G.jpg")

        # Use gen_f's own output here (the original re-used gen_g's output).
        result = gen_f(Variable(t), test=True)
        Image.fromarray(_to_tiled_image(result.data)).save(prefix + "_F.jpg")

    def evaluation(trainer):
        it = trainer.updater.get_iterator('test')
        _eval(trainer, it)

    return evaluation
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号