def get_weighted_mask(self, image_shape, mask_shape, ROI_mask=None, labels_mask=None):
    """Build one normalized weight map per label category.

    For each category ``c`` in ``range(self.n_categories)``, the map is the
    indicator ``labels_mask == c`` multiplied by the mask returned from
    ``self.get_mask_boundaries``. It is then divided by its own total so that
    each non-empty map sums to 1.

    Args:
        image_shape: Forwarded to ``self.get_mask_boundaries``.
        mask_shape: Forwarded to ``self.get_mask_boundaries``.
        ROI_mask: Optional mask forwarded to ``self.get_mask_boundaries``.
        labels_mask: Array of category labels (compared against the integers
            ``0 .. self.n_categories - 1``). Required.

    Returns:
        Float array of shape ``(self.n_categories,) + labels_mask.shape``.
        A category with no voxels inside the boundaries gets an all-zero map
        rather than NaNs from a 0/0 division.

    Raises:
        ValueError: If ``labels_mask`` is None.
    """
    if labels_mask is None:
        raise ValueError('SamplingScheme error: please specify a labels_mask for this sampling scheme')
    labels_mask = np.asarray(labels_mask)
    mask_boundaries = self.get_mask_boundaries(image_shape, mask_shape, ROI_mask)

    n_categories = self.n_categories
    # Integer accumulation kept from the original implementation
    # (assumes the boundary mask is 0/1 -- TODO confirm with get_mask_boundaries).
    final_mask = np.zeros((n_categories,) + labels_mask.shape, dtype="int16")
    for index_cat in range(n_categories):
        # Indicator of this category, zeroed outside the valid boundaries.
        final_mask[index_cat] = (labels_mask == index_cat) * mask_boundaries

    # Per-category totals, reshaped so they broadcast over every spatial axis.
    totals = final_mask.reshape(n_categories, -1).sum(axis=1)
    totals = totals.reshape((n_categories,) + (1,) * labels_mask.ndim).astype(float)

    # Categories absent from the valid region keep zeros instead of 0/0 -> NaN.
    return np.divide(
        final_mask,
        totals,
        out=np.zeros(final_mask.shape, dtype=float),
        where=totals > 0,
    )
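# 评论列表 ("Comment list") -- navigation text from the source web page, not code
# 文章目录 ("Table of contents") -- navigation text from the source web page, not code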