def convert_RGB_mask_to_index(im, colors, ignore_missing_labels=False):
    """Convert an RGB-encoded segmentation mask into a class-index mask.

    :param im: mask in RGB format, an array of shape (H, W, 3) where each
        pixel's color identifies its class.
    :param colors: ordered mapping from class name to the RGB colors that
        belong to that class (an (N, 3) array, or a single [r, g, b]). The
        position of a class in the mapping is its output index. Example::

            colors = OrderedDict([
                ("Sky", np.array([[128, 128, 128]], dtype=np.uint8)),
                ("Building", np.array([[128, 0, 0],   # Building
                                       [64, 192, 0],  # Wall
                                       [0, 128, 64],  # Bridge
                                       ], dtype=np.uint8)),
                ...
            ])

        If the same color appears in several classes, the later class wins.
    :param ignore_missing_labels: if True, pixels whose color matches no
        entry in ``colors`` are assigned to the void class (the last class
        whose colors include exactly ``[0, 0, 0]``) instead of failing.
    :return: uint8 array of shape (H, W) with the class index of each pixel.
    :raises ValueError: if ``ignore_missing_labels`` is True, some pixels are
        unmatched, and no class contains the color ``[0, 0, 0]``.
    :raises AssertionError: if some pixels are left unmatched (only possible
        when ``ignore_missing_labels`` is False).
    """
    # 255 is the "not matched yet" sentinel; class indices must stay below it.
    out = np.full(im.shape[:2], 255, dtype=np.uint8)
    void_label = None

    for grey_val, rgb in enumerate(colors.values()):
        # Normalise to (N, 3) so a single flat [r, g, b] color also works.
        class_colors = np.asarray(rgb).reshape(-1, 3)
        for color in class_colors:
            # A pixel matches only if every channel equals the color.
            out[np.all(im == color, axis=-1)] = grey_val
        if ignore_missing_labels:
            # Row-wise exact comparison: `[0, 0, 0] in array` would be true
            # whenever any single channel is 0 (e.g. [128, 0, 0]).
            if np.all(class_colors == 0, axis=-1).any():
                void_label = grey_val

    if ignore_missing_labels:
        missing = out == 255
        if missing.any():
            if void_label is None:
                raise ValueError(
                    "ignore_missing_labels=True but no class in colors "
                    "contains the void color [0, 0, 0]")
            print("Ignoring missing labels")
            out[missing] = void_label

    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if (out == 255).any():
        raise AssertionError("rounding errors or missing classes in colors")
    return out
评论列表
文章目录