def _resize_heatmap(heatmap, size, antialias):
    """Rescale a 2-D heatmap to [0, 1] and resize it to ``size``.

    Stands in for ``img_as_float(scipy.misc.imresize(...))``;
    ``scipy.misc.imresize`` was removed in SciPy 1.3.

    Args:
        heatmap: 2-D NumPy array.
        size: ``(height, width)`` of the result.
        antialias: bilinear interpolation when true, nearest-neighbour
            otherwise.

    Returns:
        A float64 NumPy array of shape ``size``.
    """
    if heatmap.dtype == np.uint8:
        # imresize left uint8 input unscaled; img_as_float then divided by 255.
        scaled = heatmap / 255.0
    else:
        # imresize stretched every other dtype so min -> 0 and max -> 1.
        scaled = heatmap.astype(np.float64)
        low = scaled.min()
        span = scaled.max() - low
        # A constant heatmap has zero span; dividing by 1 maps it to all
        # zeros, as the old byte-scaling did.
        scaled = (scaled - low) / (span if span > 0 else 1.0)

    # interpolate() wants (batch, channel, H, W), so add two leading axes.
    tensor = torch.from_numpy(np.ascontiguousarray(scaled))[None, None]
    if antialias:
        resized = torch.nn.functional.interpolate(
            tensor, size=size, mode="bilinear", align_corners=False
        )
    else:
        # align_corners is rejected for non-interpolating modes.
        resized = torch.nn.functional.interpolate(
            tensor, size=size, mode="nearest"
        )
    return resized[0, 0].numpy()


def dump_heatmaps(filename, images, heatmaps, antialias=True, multiply=True):
    """Combine each image with its heatmap and save the batch as one image.

    Every image is divided by 255, its heatmap is rescaled to [0, 1] and
    resized to the image's spatial size, and the two are multiplied or
    added.  The batch of results is written with ``vutils.save_image``
    using ``normalize=True``.

    Args:
        filename: destination passed through to ``vutils.save_image``.
        images: object whose ``.numpy()`` yields the image batch; assumed to
            be (N, C, H, W) with values in 0-255 -- TODO confirm with callers.
        heatmaps: object whose ``.numpy()`` yields one 2-D heatmap per image
            (assumed (N, h, w)).
        antialias: resize heatmaps bilinearly if true, nearest-neighbour
            otherwise.
        multiply: multiply image by heatmap if true, add them otherwise.
    """
    images = images.numpy()
    heatmaps = heatmaps.numpy()
    all_masked = np.zeros(images.shape)

    for i, raw_image in enumerate(images):
        image = raw_image / 255.0
        # Resize the heatmap to the image's (H, W).
        heatmap = _resize_heatmap(
            heatmaps[i], (image.shape[1], image.shape[2]), antialias
        )
        # A leading axis lets the heatmap broadcast over every channel of the
        # channel-first image (previously tiled to exactly 3 channels).
        heatmap = heatmap[np.newaxis]
        # Mask (or offset) the image by the heatmap.
        if multiply:
            masked = image * heatmap
        else:
            masked = image + heatmap
        all_masked[i] = masked

    vutils.save_image(torch.FloatTensor(all_masked), filename, normalize=True)
评论列表
文章目录