def encode_segmap(self, mask):
    """Convert a colour-coded segmentation mask into a 2-D array of class indices.

    Each colour returned by ``self.get_pascal_labels()`` is compared against
    every pixel of ``mask``; pixels equal to the colour at position ``i`` of
    that sequence are assigned class index ``i``.

    Args:
        mask: Array-like colour mask, presumably of shape (H, W, C) where C
            matches the length of each palette entry (e.g. RGB) -- confirm
            against callers. Anything ``np.asarray`` accepts (ndarray, nested
            lists, objects exposing the array interface) is allowed. Values
            are truncated to integers before comparison.

    Returns:
        np.ndarray of shape (H, W) and dtype ``int`` holding class indices.
        Pixels whose colour matches no palette entry stay 0. If a colour
        appears more than once in the palette, the last matching index wins
        because later iterations overwrite earlier ones.
    """
    # np.asarray accepts array-likes, not only ndarrays with an .astype method.
    mask = np.asarray(mask).astype(int)
    # Allocate directly in the returned dtype (no int16 -> int round trip).
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=int)
    for i, label in enumerate(self.get_pascal_labels()):
        # True where every channel of the pixel equals this palette colour.
        matches = np.all(mask == label, axis=-1)
        label_mask[matches] = i
    return label_mask
# [Scraped web-page residue] 评论列表 ("Comment list")
# [Scraped web-page residue] 文章目录 ("Table of contents")