heatmap_model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:MachineLearning 作者: timomernick 项目源码 文件源码
def dump_heatmaps(filename, images, heatmaps, antialias=True, multiply=True):
    """Combine each image with its heatmap and save the batch as one image.

    Each heatmap is resized to the spatial size of its image, combined with
    the image (first divided by 255) by multiplication or addition, and the
    resulting batch is written with ``vutils.save_image(..., normalize=True)``.

    Args:
        filename: Destination handed unchanged to ``vutils.save_image``.
        images: Object exposing ``.numpy()``. The array is indexed
            channel-first: axis 0 is the batch, and axes 1 and 2 of each
            image give the size the heatmap is resized to. Presumably pixel
            values are in 0..255 (they are divided by 255) -- confirm.
        heatmaps: Object exposing ``.numpy()`` with one heatmap per image
            along axis 0. Presumably each entry is a 2-D map -- confirm
            against the caller.
        antialias: If true, resize with ``"bilinear"`` interpolation,
            otherwise with ``"nearest"``.
        multiply: If true, the image is multiplied by the heatmap;
            otherwise the heatmap is added to the image.
    """
    images = images.numpy()
    heatmaps = heatmaps.numpy()
    all_masked = np.zeros(images.shape)

    # The interpolation mode is the same for every image, so choose it once.
    interp = "bilinear" if antialias else "nearest"

    for i, raw_image in enumerate(images):
        image = raw_image / 255.0

        # Resize the heatmap to the spatial size of the image.
        # NOTE(review): scipy.misc.imresize was deprecated in SciPy 1.0 and
        # removed in 1.3, so this needs an old SciPy. It is kept on purpose:
        # it appears to rescale its input to 0..255 before resizing, so a
        # drop-in replacement would change the resulting heatmap values.
        heatmap = img_as_float(
            scipy.misc.imresize(
                heatmaps[i], [image.shape[1], image.shape[2]], interp=interp
            )
        )

        # Add a leading length-1 channel axis and let NumPy broadcast it
        # against the image's channels. For 3-channel images this yields the
        # same values as tiling the heatmap three times, and it also works
        # for any other channel count.
        heatmap = heatmap[np.newaxis, :, :]

        # Mask (or offset) the image by the heatmap.
        if multiply:
            masked = image * heatmap
        else:
            masked = image + heatmap

        all_masked[i] = masked

    vutils.save_image(torch.FloatTensor(all_masked), filename, normalize=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号