def create_checkpoint_mask(img, mask, predicted_mask, *, threshold=0.5,
                           true_alpha=0.3, pred_alpha=0.5):
    """Overlay the ground-truth mask and a binarized prediction on an image.

    The ground truth is colourised with ``mask_to_bgr(mask, 0, 255, 0)``.
    The thresholded prediction is colourised with
    ``mask_to_bgr(p_mask, 0, 0, 255)``. Both are then blended onto ``img``
    with ``cv2.addWeighted``. The channel meaning of the colour arguments is
    defined by ``mask_to_bgr``; presumably they are B, G, R, giving a green
    truth and a red prediction (confirm).

    Args:
        img: Image to draw on. ``cv2.addWeighted`` requires it to match the
            colourised masks in shape and dtype. Presumably it is a
            (CARVANA_H, CARVANA_W, 3) uint8 BGR image (confirm).
        mask: Ground-truth mask, passed unchanged to ``mask_to_bgr``.
        predicted_mask: Prediction map indexed as (height, width) and
            expected to be wider than tall.
            - If its shape is exactly (CARVANA_H, CARVANA_W + 2), one column
              is cropped from each side.
            - Otherwise it is resized to (CARVANA_H, CARVANA_W) with
              nearest-neighbour interpolation.
        threshold: Prediction values strictly greater than this become 1;
            all others become 0.
        true_alpha: Blend weight applied to the ground-truth overlay.
        pred_alpha: Blend weight applied to the prediction overlay.

    Returns:
        The blended image produced by the second ``cv2.addWeighted`` call.

    Raises:
        ValueError: If ``predicted_mask`` is not wider than it is tall.
    """
    p_mask = predicted_mask
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if p_mask.shape[0] >= p_mask.shape[1]:
        raise ValueError(
            "predicted_mask must be wider than tall (height, width); "
            f"got shape {p_mask.shape}")
    if p_mask.shape == (CARVANA_H, CARVANA_W + 2):
        # The width is off by exactly two: drop one column on each side.
        # Presumably this is padding added before inference (confirm).
        p_mask = p_mask[:, 1:-1]
    else:
        # cv2.resize takes dsize as (width, height). Nearest-neighbour
        # interpolation copies existing values rather than inventing
        # in-between ones before thresholding.
        p_mask = cv2.resize(p_mask, (CARVANA_W, CARVANA_H),
                            interpolation=cv2.INTER_NEAREST)
    p_mask = (p_mask > threshold).astype(np.uint8)
    true_mask = mask_to_bgr(mask, 0, 255, 0)
    p_mask = mask_to_bgr(p_mask, 0, 0, 255)
    # Keep the image at full weight, add each overlay scaled by its alpha,
    # and use a gamma (scalar offset) of 0.
    blended = cv2.addWeighted(img, 1.0, true_mask, true_alpha, 0)
    blended = cv2.addWeighted(blended, 1.0, p_mask, pred_alpha, 0)
    return blended
评论列表
文章目录