def save(mask, img, blurred, out_dir=""):
    """Write mask, heatmap, CAM overlay and perturbed-image PNGs.

    Four files are written, in this order: ``perturbated.png``,
    ``heatmap.png``, ``mask.png`` and ``cam.png``.

    Args:
        mask: Tensor; element ``[0]`` must be 3-D. Its leading axis is moved
            last (presumably CHW -> HWC for a (1, 1, H, W) mask -- confirm
            with caller).
        img: Image with pixel values on a 0-255 scale (assumed; it is divided
            by 255). It is used both for the CAM overlay and as the
            un-perturbed image.
        blurred: Blurred counterpart of ``img``. It is mixed with ``img / 255``
            without any rescaling, so it is expected to already be on a 0-1
            scale (assumption -- confirm with caller).
        out_dir: Directory to write the PNGs into. The default ``""`` writes
            to the current working directory, as before.

    Returns:
        None. The return value of ``cv2.imwrite`` is not checked.
    """
    import os

    def to_uint8(x):
        # Clip before casting so out-of-range floats saturate at 0/255
        # instead of wrapping around when converted to uint8.
        return np.uint8(np.clip(255 * x, 0, 255))

    mask = mask.detach().cpu().numpy()[0]
    mask = np.transpose(mask, (1, 2, 0))

    # Min-max normalise to [0, 1]. Dividing by max(mask) (rather than by
    # max - min) left masks with a positive minimum short of 1, and turned
    # an all-zero mask into NaNs.
    lo = np.min(mask)
    span = np.max(mask) - lo
    mask = (mask - lo) / span if span > 0 else np.zeros_like(mask)
    mask = 1 - mask

    heatmap = cv2.applyColorMap(to_uint8(mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255

    img = np.float32(img) / 255

    # Overlay the heatmap on the image and rescale so the brightest pixel is 1.
    cam = heatmap + img
    cam_max = np.max(cam)
    if cam_max > 0:
        cam = cam / cam_max

    # Where the (inverted) mask is 1 the blurred image shows through;
    # where it is 0 the original image is kept.
    perturbated = (1 - mask) * img + mask * blurred

    outputs = (
        ("perturbated.png", perturbated),
        ("heatmap.png", heatmap),
        ("mask.png", mask),
        ("cam.png", cam),
    )
    for filename, image in outputs:
        cv2.imwrite(os.path.join(out_dir, filename), to_uint8(image))