def load_image(self, im_id: str) -> Image:
    """Load an image and its class masks, using an on-disk cache.

    The preprocessed image array is cached as ``im_cache/<im_id>.data`` and
    the stacked class masks as ``im_cache/<im_id>.mask`` (both written with
    ``np.save``; ``im_cache`` is relative to the current working directory).
    When ``self.hps.pre_buffer`` is truthy the mask cache is neither read nor
    written, because the buffered mask depends on that setting.

    Args:
        im_id: Identifier passed to ``utils.load_image`` and
            ``utils.load_polygons``.

    Returns:
        ``Image(im_id, im_data, masks)`` where ``im_data`` is truncated to
        its first ``self.hps.n_channels`` entries along axis 0 and ``masks``
        is the mask stack indexed by ``self.hps.classes``.
    """
    logger.info('Loading {}'.format(im_id))
    im_cache = Path('im_cache')
    im_cache.mkdir(exist_ok=True)
    im_data_path = im_cache.joinpath('{}.data'.format(im_id))
    mask_path = im_cache.joinpath('{}.mask'.format(im_id))

    def save_atomic(path, array):
        # Write to a sibling temp file and rename it into place, so an
        # interrupted run never leaves a truncated cache file behind (the
        # cache is trusted purely on path.exists(), so a partial file would
        # make every later np.load() fail).
        tmp_path = path.with_name(path.name + '.tmp')
        with tmp_path.open('wb') as f:
            np.save(f, array)
        tmp_path.replace(path)

    if im_data_path.exists():
        im_data = np.load(str(im_data_path))
    else:
        im_data = self.preprocess_image(utils.load_image(im_id))
        save_atomic(im_data_path, im_data)

    pre_buffer = self.hps.pre_buffer
    if mask_path.exists() and not pre_buffer:
        mask = np.load(str(mask_path))
    else:
        # Axis 0 of im_data is treated as channels (see the n_channels slice
        # below); the remaining axes are the size masks are rasterised at.
        im_size = im_data.shape[1:]
        poly_by_type = utils.load_polygons(im_id, im_size)
        if pre_buffer:
            structures = 2  # polygon-class key that gets buffered
            poly_by_type[structures] = utils.to_multipolygon(
                poly_by_type[structures].buffer(pre_buffer))
        # poly_by_type is keyed from 1, so mask row `cls` holds class cls + 1.
        mask = np.array(
            [utils.mask_for_polygons(im_size, poly_by_type[cls + 1])
             for cls in range(self.hps.total_classes)],
            dtype=np.uint8)
        if not pre_buffer:
            # Only unbuffered masks are cached; buffered ones vary with
            # pre_buffer and are recomputed on every call.
            save_atomic(mask_path, mask)

    if self.hps.n_channels != im_data.shape[0]:
        im_data = im_data[:self.hps.n_channels]
    return Image(im_id, im_data, mask[self.hps.classes])
评论列表
文章目录