sampling.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:segmentation_DLMI 作者: imatge-upc 项目源码 文件源码
def get_weighted_mask(self, image_shape, mask_shape,ROI_mask=None, labels_mask=None):

        if labels_mask is  None:
            raise ValueError('SamplingScheme error: please specify a labels_mask for this sampling scheme')
        print(np.unique(labels_mask))
        mask_boundaries = self.get_mask_boundaries(image_shape, mask_shape,ROI_mask)


        final_mask = np.zeros((self.n_categories,) + labels_mask.shape, dtype="int16")
        for index_cat in range(self.n_categories):
            final_mask[index_cat] = (labels_mask == index_cat,) * mask_boundaries

        final_mask = 1.0 * final_mask / np.reshape(np.sum(np.reshape(final_mask,(self.n_categories,-1)),axis=1),(self.n_categories,)+(1,)*len(image_shape))

        print(np.sum(np.reshape(final_mask,(self.n_categories,-1)),axis=1))
        return final_mask
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号