def draw_seg(self, img, seg_gt, segmentation, name):
    """Save an image, its predicted segmentation mask and optional ground truth.

    Writes into ``self.directory``:

    * ``<name>_segmentation.png`` -- ``segmentation`` resized (nearest
      neighbour) to the width/height of ``img`` and saved as a palettised
      ('P' mode) PNG;
    * ``<name>.jpg`` -- ``img`` scaled by 255 and saved as JPEG;
    * ``<name>_seg_gt.png`` -- ``seg_gt`` saved as a palettised PNG, only when
      ``seg_gt`` is not ``None``. The ground truth is NOT resized.

    The colour palette is read from ``Extra/palette.npy`` (a path relative to
    the current working directory) on every call.

    Args:
        img: Image array; ``img.shape[0]`` is used as the height and
            ``img.shape[1]`` as the width. Presumably float values in [0, 1]
            (it is multiplied by 255) -- confirm with callers. Values outside
            that range are clipped rather than wrapped.
        seg_gt: Ground-truth label map, or ``None`` to skip writing it.
        segmentation: Predicted label map (one class index per pixel).
        name: Base file name (without extension) for the output files.

    Label values are stored as uint8, so ids outside 0-255 wrap around.
    """
    # PIL's putpalette needs a flat sequence of ints; ravel() also accepts a
    # palette stored as (N, 3) rows and is a no-op for an already-flat one.
    palette = np.load('Extra/palette.npy').ravel().tolist()

    def _save_label_png(label_map, suffix):
        """Save a uint8 label map as a palettised PNG named <name>_<suffix>.png."""
        label_img = Image.fromarray(label_map, 'P')
        label_img.putpalette(palette)
        label_img.save(self.directory + '/%s_%s.png' % (name, suffix), 'PNG')

    # cv2 expects dsize as (width, height).
    img_size = (img.shape[1], img.shape[0])
    # Cast to uint8 *before* resizing: cv2.resize does not accept every numpy
    # dtype (e.g. bool masks). Nearest-neighbour resizing only copies existing
    # values, so casting first yields the same labels as casting afterwards.
    seg = np.asarray(segmentation).astype('uint8')
    seg = cv2.resize(seg, dsize=img_size, interpolation=cv2.INTER_NEAREST)
    _save_label_png(seg, 'segmentation')

    # Clip before the uint8 cast so out-of-range intensities saturate instead
    # of wrapping around (e.g. 257 -> 1).
    image = Image.fromarray(np.clip(img * 255, 0, 255).astype('uint8'))
    image.save(self.directory + '/%s.jpg' % name, 'JPEG')

    # Compare against None explicitly: truth-testing a multi-element numpy
    # array raises ValueError ("truth value ... is ambiguous").
    if seg_gt is not None:
        _save_label_png(np.asarray(seg_gt).astype('uint8'), 'seg_gt')