infer.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:wavenet 作者: rampage644 项目源码 文件源码
def generate_and_save_samples(sample_fn, height, width, channels, count, filename):
    """Autoregressively draw ``count ** 2`` images and save them as one grid image.

    Generation starts from an all-zero float32 batch allocated on the GPU
    (via ``chainer.cuda.cupy``) with shape ``(count ** 2, channels, height, width)``.
    Positions are filled in raster order: row ``i``, then column ``j``, then
    channel ``k``. For each position the whole batch is re-run through
    ``sample_fn``, a softmax is taken over axis 1 of the logits at that
    position, and one level per image is drawn with ``utils.sample_from``.
    The drawn value is divided by ``level_count - 1`` before being written
    back. Assuming ``sample_from`` returns integer level indices in
    ``[0, level_count)`` (TODO confirm), stored pixels therefore lie in ``[0, 1]``.

    Args:
        sample_fn: Callable taking the current ``chainer.Variable`` of samples
            and returning a rank-5 array/Variable indexed as
            ``[batch, level, channel, row, col]``. Presumably these are
            unnormalized logits over intensity levels; confirm against the model.
        height: Image height in pixels.
        width: Image width in pixels.
        channels: Number of channels per image.
        count: Side length of the output grid; ``count ** 2`` images are drawn.
        filename: Destination passed to the image object's ``save`` method.
    """
    def save_images(images):
        """Tile a ``(count ** 2, channels, height, width)`` batch into one image and save it."""
        images = images.reshape((count, count, channels, height, width))
        # Produces (count * height, count * width, channels) after the reshape
        # below. Sample n = a * count + b lands at grid row b, grid column a,
        # i.e. consecutive samples run down the columns of the grid.
        images = images.transpose(1, 3, 0, 4, 2)
        if channels == 1:
            # scipy.misc.toimage rejects (H, W, 1) arrays: it only accepts 2-D
            # input or 3-D input with a size-3/4 axis. Save grayscale grids
            # as plain 2-D arrays instead.
            images = images.reshape((height * count, width * count))
        else:
            images = images.reshape((height * count, width * count, channels))
        # NOTE(review): scipy.misc.toimage was deprecated in SciPy 1.0 and
        # removed in SciPy 1.2 (it also requires Pillow); this call needs an
        # older SciPy install.
        scipy.misc.toimage(images, cmin=0.0, cmax=255.0).save(filename)

    samples = chainer.Variable(
        chainer.cuda.cupy.zeros((count ** 2, channels, height, width), dtype='float32'))

    with tqdm.tqdm(total=height * width * channels) as bar:
        for i in range(height):
            for j in range(width):
                for k in range(channels):
                    # F.softmax normalises over axis 1 only. Slicing the logits
                    # for this position first yields the same (batch, levels)
                    # probabilities without normalising every other position of
                    # the full output on each step.
                    logits = sample_fn(samples)[:, :, k, i, j]
                    probs = F.softmax(logits)
                    _, level_count = probs.shape
                    levels = utils.sample_from(probs.data.get())
                    # Scale drawn levels to [0, 1] before storing them as the
                    # conditioning input for the next step.
                    samples.data[:, k, i, j] = chainer.cuda.to_gpu(levels / (level_count - 1))
                    bar.update()
    samples.to_cpu()

    save_images(samples.data * 255.0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号