def dist_mask(mask, max_dist=10, *, power=1.0):
    """Encode a binary mask as a normalized signed distance map in [0, 1].

    Every nonzero pixel of ``mask`` (after casting to uint8) is foreground.
    Foreground pixels get their distance to the nearest background pixel.
    Background pixels get their distance to the nearest foreground pixel.
    Both distances are clipped at ``max_dist`` and scaled to [0, 1], then
    combined into a signed distance: positive inside, negative outside.
    The result is shifted into [0, 1], so foreground pixels land above 0.5
    and background pixels below 0.5.

    Args:
        mask: Single-channel (2-D) array-like; nonzero entries are
            foreground. Boolean, 0/1 and 0/255 masks give the same result.
        max_dist: Distance in pixels at which the map saturates. Must be > 0.
        power: Exponent applied to the magnitude of the normalized signed
            distance; the sign is kept. ``1.0`` (the default) leaves the map
            unchanged. Values < 1 stretch small distances near the boundary.

    Returns:
        Float array with the same shape as ``mask`` and values in [0, 1].

    Raises:
        ValueError: If ``max_dist`` or ``power`` is not positive.
    """
    if max_dist <= 0:
        raise ValueError(f"max_dist must be positive, got {max_dist!r}")
    if power <= 0:
        raise ValueError(f"power must be positive, got {power!r}")

    # Cast first (keeps the original truncation semantics for float input),
    # then binarize. With a raw 0/255 mask, `1 - mask` would wrap around in
    # uint8 (1 - 255 == 2) and the "inverse" mask would be nonzero everywhere.
    mask = (np.asarray(mask).astype(np.uint8) > 0).astype(np.uint8)

    def get_dist(m):
        # Distance from each nonzero pixel of `m` to the nearest zero pixel
        # (L2 metric, approximated by OpenCV with a 3x3 mask).
        d = cv2.distanceTransform(m, cv2.DIST_L2, maskSize=3)
        d[d > max_dist] = max_dist  # saturate pixels farther than max_dist
        return d / max_dist  # normalize to [0, 1]

    # Inside term minus outside term: positive inside, negative outside,
    # with values in [-1, 1].
    dist = get_dist(mask) - get_dist(1 - mask)

    if power != 1:
        # Signed power: reshape the magnitude, keep the inside/outside sign.
        dist = np.sign(dist) * np.abs(dist) ** power

    return (1 + dist) / 2  # map [-1, 1] -> [0, 1]
评论列表
文章目录