def get_segmentation_image(segdb, config):
    """Load, randomly rescale and tensorize images with their segmentation labels.

    For each record one scale is drawn at random from ``config.SCALES``; the
    image and its class-label map are both resized to that same scale and then
    converted with ``transform`` / ``transform_seg_gt`` respectively.

    :param segdb: non-empty list of record dicts. Each record must provide an
        ``'image'`` path (decoded with OpenCV) and a ``'seg_cls_path'`` path
        (decoded with PIL). Records are shallow-copied, never mutated.
    :param config: configuration object; this function reads ``SCALES`` (a
        sequence whose entries are indexed as ``[0]`` = target size and
        ``[1]`` = max size), ``network.IMAGE_STRIDE`` and
        ``network.PIXEL_MEANS``.
    :return: tuple ``(processed_ims, processed_seg_cls_gt, processed_segdb)``
        of equal-length lists: the transformed image tensors, the transformed
        label tensors, and the copied records, each with an added
        ``'im_info'`` entry ``[shape[2], shape[3], im_scale]`` of the image
        tensor (presumably height/width of an NCHW tensor -- confirm against
        ``transform``).
    :raises AssertionError: if ``segdb`` is empty or an image path does not
        exist.
    :raises IOError: if an existing image file cannot be decoded by OpenCV.
    """
    num_images = len(segdb)
    assert num_images > 0, 'No images'
    processed_ims = []
    processed_segdb = []
    processed_seg_cls_gt = []
    for seg_rec in segdb:
        image_path = seg_rec['image']
        assert os.path.exists(image_path), '{} does not exist'.format(image_path)
        im = cv2.imread(image_path)
        if im is None:
            # cv2.imread signals unreadable/corrupt files by returning None
            # instead of raising; fail here with a clear message.
            raise IOError('{} could not be decoded as an image'.format(image_path))
        new_rec = seg_rec.copy()

        # One random scale per record, shared by the image and its label map
        # so that both are resized with the same parameters.
        scale = config.SCALES[random.randrange(len(config.SCALES))]
        target_size, max_size = scale[0], scale[1]

        im, im_scale = resize(im, target_size, max_size,
                              stride=config.network.IMAGE_STRIDE)
        im_tensor = transform(im, config.network.PIXEL_MEANS)
        new_rec['im_info'] = [im_tensor.shape[2], im_tensor.shape[3], im_scale]

        # Close the label file promptly; np.array forces the pixel data to load.
        with Image.open(seg_rec['seg_cls_path']) as seg_img:
            seg_cls_gt = np.array(seg_img)
        # Nearest-neighbour interpolation keeps label values as the original
        # class ids instead of blending neighbouring labels.
        seg_cls_gt, _ = resize(seg_cls_gt, target_size, max_size,
                               stride=config.network.IMAGE_STRIDE,
                               interpolation=cv2.INTER_NEAREST)

        processed_ims.append(im_tensor)
        processed_segdb.append(new_rec)
        processed_seg_cls_gt.append(transform_seg_gt(seg_cls_gt))
    return processed_ims, processed_seg_cls_gt, processed_segdb
评论列表
文章目录