def vis_col_result(im, seg, gt, savefile=None, colors=None):
    """Show a predicted segmentation side by side with its ground truth.

    The image is scaled by its maximum and converted to RGB. Two copies are
    then made, one tinted according to ``seg`` and one according to ``gt``.
    Every pixel whose label appears in ``colors`` has its RGB value multiplied
    by that label's tint. Pixels with any other label stay gray.

    Label meanings, per the original annotations:
        1 = metacarpal, 2 = proximal, 3 = middle (thumb: distal),
        4 = distal (thumb: none). Label 0 is presumably background
        (TODO confirm).

    Args:
        im: Grayscale image. Assumed 2-D (H, W); TODO confirm against callers.
        seg: Predicted label map, indexed on the same first two axes as ``im``.
        gt: Ground-truth label map, same layout as ``seg``.
        savefile: If given, the figure is saved to this path instead of being
            shown interactively.
        colors: Optional mapping ``label -> [r, g, b]`` of multiplicative
            tints. Defaults to the five-label palette below.
    """
    if colors is None:
        colors = {
            0: [0.6, 0.6, 1.],   # light blue
            1: [0.2, 1., 0.2],   # green
            2: [1., 1., 0.2],    # yellow
            3: [1., 0.6, 0.2],   # orange
            4: [1., 0., 0.],     # red
        }

    # Work in floating point so the division below is not integer division.
    im = im * 1.
    peak = im.max()
    # An all-zero image would otherwise produce 0/0 -> NaN everywhere.
    if peak != 0:
        im = im / peak
    rgb_image = color.gray2rgb(im)

    im_seg = rgb_image.copy()
    im_gt = rgb_image.copy()
    for label, tint in colors.items():
        # Boolean mask over (H, W) selects an (k, 3) slice; the tint broadcasts
        # across the colour channels.
        im_seg[seg == label] *= tint
        im_gt[gt == label] *= tint

    fig = plt.figure()
    ax_seg = fig.add_subplot(1, 2, 1)
    ax_seg.imshow(im_seg)
    ax_seg.set_title('Segmentation')
    ax_gt = fig.add_subplot(1, 2, 2)
    ax_gt.imshow(im_gt)
    ax_gt.set_title('Ground truth')
    if savefile is not None:
        fig.savefig(savefile)
    else:
        plt.show()
    # Close exactly the figure created here, not whatever is "current".
    plt.close(fig)
评论列表
文章目录