def draw_label(label, img, n_class, label_titles, bg_label=0):
    """Convert label to rgb with label titles.

    The label map is blended over ``img`` with ``label2rgb``, a legend entry
    is drawn with matplotlib for every value present in ``label`` that has a
    key in ``label_titles``, and the rendered figure is read back as an array
    resized to ``img.shape``.

    @param label: label map; its unique values index the colormap returned
        by ``labelcolormap`` and are looked up in ``label_titles``.
    @param img: image the labels are overlaid on; its ``shape`` and ``dtype``
        determine those of the returned array.
    @param n_class: number of classes, passed to ``labelcolormap``.
    @param label_titles: label title for each label value.
    @type label_titles: dict
    @param bg_label: forwarded unchanged to ``label2rgb`` as ``bg_label``.
    @return: rendered visualization, resized to ``img.shape`` and cast to
        ``img.dtype``.
    """
    import io

    from PIL import Image
    from skimage.color import label2rgb
    from skimage.transform import resize

    colors = labelcolormap(n_class)
    label_viz = label2rgb(label, img, colors=colors[1:], bg_label=bg_label)
    # label 0 color: (0, 0, 0, 0) -> (0, 0, 0, 255)
    label_viz[label == 0] = 0

    # The encoded figure is binary data, so it needs a bytes buffer
    # (io.BytesIO works on both Python 2 and Python 3).
    buf = io.BytesIO()
    try:
        # plot label titles on image using matplotlib
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                            wspace=0, hspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.axis('off')
        # plot image
        plt.imshow(label_viz)
        # plot legend: only label values that are present in the label map
        # and have a title get an entry
        shown_values = [v for v in np.unique(label) if v in label_titles]
        plt_handlers = [plt.Rectangle((0, 0), 1, 1, fc=colors[v])
                        for v in shown_values]
        plt_titles = [label_titles[v] for v in shown_values]
        plt.legend(plt_handlers, plt_titles, loc='lower right',
                   framealpha=0.5)
        # render the plotted figure into the in-memory buffer
        plt.savefig(buf, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the global pyplot figure; otherwise every call would draw
        # on top of the previous one and keep its memory alive.
        plt.cla()
        plt.close()

    # convert plotted figure to np.ndarray
    buf.seek(0)  # savefig leaves the position at the end of the buffer
    result_img_pil = Image.open(buf)
    # Same result as the removed scipy.misc.fromimage(..., mode='RGB').
    result_img = np.asarray(result_img_pil.convert('RGB'))
    result_img = resize(result_img, img.shape, preserve_range=True)
    result_img = result_img.astype(img.dtype)
    return result_img
utils.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录