visualize.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:GAN 作者: lyakaap 项目源码 文件源码
def out_generated_image(gen, dis, rows, cols, seed, dst,
                        output_activation='sigmoid'):
    """Build a trainer extension that saves a grid of generator samples.

    Each time the extension runs it draws ``rows * cols`` latent vectors via
    ``gen.make_hidden`` under a fixed numpy seed. It then runs the generator
    in test mode without building a backprop graph, tiles the outputs into
    one image, and writes it to
    ``<dst>/preview_LSGAN_pixel_shuffler/image<iteration, 8 digits>.png``.

    Args:
        gen: Generator link. Must expose ``xp`` and ``make_hidden(n)`` and be
            callable on a latent ``Variable``. Its output is unpacked as a
            4-D ``(N, C, H, W)`` array.
        dis: Discriminator. Unused here; kept for interface compatibility.
        rows: Number of image rows in the preview grid.
        cols: Number of image columns in the preview grid.
        seed: Seed given to ``np.random.seed`` before sampling, so every
            preview is generated from the same latent vectors.
        dst: Root output directory.
        output_activation: ``'sigmoid'`` (default) when the generator output
            lies in [0, 1]; ``'tanh'`` when it lies in [-1, 1].

    Returns:
        A ``chainer.training.make_extension``-decorated function that takes
        the trainer.

    Raises:
        ValueError: if ``output_activation`` is not ``'sigmoid'`` or
            ``'tanh'``.
    """
    if output_activation not in ('sigmoid', 'tanh'):
        raise ValueError(
            "output_activation must be 'sigmoid' or 'tanh', got {!r}".format(
                output_activation))

    @chainer.training.make_extension()
    def make_image(trainer):
        n_images = rows * cols
        xp = gen.xp

        # Seed numpy so every preview uses the same latent vectors. Afterwards,
        # restore the previous global state instead of reseeding from entropy,
        # so the preview does not disturb the random stream used by training.
        # NOTE(review): this assumes make_hidden draws from numpy's RNG (its
        # result is passed through xp.asarray); confirm for GPU generators.
        rng_state = np.random.get_state()
        np.random.seed(seed)
        try:
            z = Variable(xp.asarray(gen.make_hidden(n_images)))
            # Test-mode forward pass. No gradients are needed, so skip
            # building the computational graph.
            with chainer.using_config('train', False):
                with chainer.using_config('enable_backprop', False):
                    x = gen(z)
        finally:
            np.random.set_state(rng_state)
        x = chainer.cuda.to_cpu(x.data)

        if output_activation == 'tanh':
            x = (x + 1) * 0.5  # map [-1, 1] onto [0, 1]
        x = np.asarray(np.clip(x * 255, 0.0, 255.0), dtype=np.uint8)

        # Tile (rows*cols, C, H, W) -> (rows, H, cols, W, C)
        # -> (rows*H, cols*W, C).
        _, n_channels, height, width = x.shape
        x = x.reshape((rows, cols, n_channels, height, width))
        x = x.transpose(0, 3, 1, 4, 2)
        x = x.reshape((rows * height, cols * width, n_channels))
        if n_channels == 1:
            # Keep the 2-D grayscale array; PIL's fromarray does not accept
            # an (H, W, 1) uint8 array.
            x = x[:, :, 0]

        preview_dir = '{}/preview_LSGAN_pixel_shuffler'.format(dst)
        preview_path = preview_dir +\
            '/image{:0>8}.png'.format(trainer.updater.iteration)
        # Create the directory without an exists()/makedirs() race. This is
        # written without exist_ok so it also runs on Python 2.
        try:
            os.makedirs(preview_dir)
        except OSError:
            if not os.path.isdir(preview_dir):
                raise
        Image.fromarray(x).save(preview_path)
    return make_image
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号