def overlay_heatmap(image, heatmap, cmap='jet', vmin=0, vmax=1, img_ratio=0.4):
    """Create a visualization of the image with an overlaid, colour-mapped heatmap.

    The image is converted to grayscale, the heatmap is normalised with
    ``vmin``/``vmax`` and colour-mapped with ``cmap``, and the two are blended
    linearly. Pixels where ``heatmap`` is NaN show only the grayscale image.

    Args:
        image: 2-D grayscale or 3-D RGB image array. RGB input goes through
            ``skimage.color.rgb2gray``; 2-D input goes through
            ``skimage.img_as_float`` so integer images land on the same
            ``[0, 1]`` scale as the RGB path (float images are left as-is).
        heatmap: 2-D array with the same height/width as ``image``. NaN marks
            pixels without a heatmap value.
        cmap: Matplotlib colormap name, ``Colormap`` instance, or ``None`` for
            the ``image.cmap`` rcParams default.
        vmin: Heatmap value mapped to the bottom of the colormap.
        vmax: Heatmap value mapped to the top of the colormap.
        img_ratio: Weight of the grayscale image in the blend; the colour-mapped
            heatmap gets weight ``1 - img_ratio``.

    Returns:
        Float array of shape ``(H, W, 3)`` holding the blended RGB overlay.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D, if ``heatmap`` does not
            match the image's height/width, or if ``vmin == vmax``.
    """
    if image.ndim == 3:
        img_gray = skimage.color.rgb2gray(image)
    elif image.ndim == 2:
        # rgb2gray already yields floats in [0, 1]; put integer grayscale
        # images on the same scale so the blend weights mean the same thing.
        img_gray = skimage.img_as_float(image)
    else:
        raise ValueError('Image should be grayscale or rgb')

    heatmap = np.asarray(heatmap)
    if heatmap.shape != img_gray.shape:
        raise ValueError(
            f'heatmap shape {heatmap.shape} does not match image shape {img_gray.shape}')
    if vmax == vmin:
        raise ValueError('vmax must differ from vmin')

    heatmap_norm = (heatmap - vmin) / (vmax - vmin)

    if not isinstance(cmap, mpl.colors.Colormap):
        if cmap is None:
            cmap = mpl.rcParams['image.cmap']
        # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
        # use the colormap registry (matplotlib >= 3.5) when it exists.
        registry = getattr(mpl, 'colormaps', None)
        cmap = registry[cmap] if registry is not None else mpl.cm.get_cmap(cmap)

    heatmap_rgb = cmap(heatmap_norm)[..., :3]  # drop the alpha channel
    img_rgb = np.repeat(img_gray[..., np.newaxis], 3, axis=-1)
    heatmap_overlay = (1.0 - img_ratio) * heatmap_rgb + img_ratio * img_rgb

    # NaN heatmap entries carry no value; show the plain grayscale image there.
    nan_mask = np.isnan(heatmap)
    heatmap_overlay[nan_mask] = img_rgb[nan_mask]
    return heatmap_overlay
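# Comment list  (scraped web-page navigation residue; originally "评论列表")
# Table of contents  (scraped web-page navigation residue; originally "文章目录")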