dataset.py source file

python

Project: segmentation_DLMI    Author: imatge-upc
import numpy as np


def get_class_weights(self, subject_list, mask_bool=True):
    """Compute per-class weights from label frequencies (median-frequency-style balancing).

    mask_bool selects which voxels are counted:
      'ROI'    -> only voxels inside each subject's ROI mask
      'labels' -> only foreground voxels (label > 0); background is excluded
      anything else (including the default True) -> every voxel
    """
    class_frequencies = np.zeros(self.n_classes)

    for subj in subject_list:
        labels = subj.load_labels().flatten().astype('int')
        if mask_bool == 'ROI':
            mask = subj.load_ROI_mask().flatten().astype('int')
            class_frequencies += np.bincount(labels, weights=mask, minlength=self.n_classes)
        elif mask_bool == 'labels':
            # Background voxels get weight 0, and bin 0 is dropped,
            # so class_frequencies[k] counts label k + 1.
            mask = (labels > 0).astype('int')
            class_frequencies += np.bincount(labels, weights=mask,
                                             minlength=self.n_classes + 1)[1:]
        else:
            class_frequencies += np.bincount(labels, minlength=self.n_classes)

    # Normalise the counts to frequencies.
    class_frequencies = class_frequencies / np.sum(class_frequencies)
    # Reference frequency: entry ceil(n_classes / 2) of the sorted frequencies.
    reference = np.sort(class_frequencies)[int(np.ceil(1.0 * self.n_classes / 2))]
    # Classes that never occur keep weight 0 instead of an infinite weight.
    class_weight = np.zeros_like(class_frequencies)
    present = class_frequencies > 0
    class_weight[present] = reference / class_frequencies[present]

    return class_weight
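Because the method depends on a dataset object (`self.n_classes`) and on subject objects that expose `load_labels()` and `load_ROI_mask()`, it cannot be run on its own. The sketch below reproduces the same counting and weighting logic as a standalone function so the three `mask_bool` modes can be checked by hand. The function name `class_weights`, its `mode` and `roi_masks` parameters, and the toy label arrays are hypothetical; it assumes labels are non-negative integers small enough to fit in the requested number of bins.

```python
import numpy as np


def class_weights(label_maps, n_classes, mode=None, roi_masks=None):
    """Hypothetical standalone mirror of get_class_weights.

    label_maps: list of integer label arrays, one per subject.
    mode: 'ROI' counts only voxels inside roi_masks, 'labels' counts only
          foreground voxels (label > 0) and drops background, anything else
          counts every voxel.
    """
    freqs = np.zeros(n_classes)
    for i, labels in enumerate(label_maps):
        labels = np.asarray(labels).flatten().astype(int)
        if mode == 'ROI':
            mask = np.asarray(roi_masks[i]).flatten().astype(int)
            freqs += np.bincount(labels, weights=mask, minlength=n_classes)
        elif mode == 'labels':
            mask = (labels > 0).astype(int)
            # Bin 0 (background) is dropped, so freqs[k] counts label k + 1.
            freqs += np.bincount(labels, weights=mask, minlength=n_classes + 1)[1:]
        else:
            freqs += np.bincount(labels, minlength=n_classes)

    freqs = freqs / freqs.sum()
    # Same reference index as the original: ceil(n_classes / 2) in the sorted frequencies.
    reference = np.sort(freqs)[int(np.ceil(n_classes / 2))]
    weights = np.zeros(n_classes)
    present = freqs > 0
    weights[present] = reference / freqs[present]  # absent classes keep weight 0
    return weights


# One toy "subject" with 8 voxels.
labels = [np.array([[0, 0, 0, 0], [1, 1, 2, 3]])]
roi = [np.array([[1, 1, 0, 0], [1, 1, 1, 1]])]

print(class_weights(labels, 4))                        # count every voxel
print(class_weights(labels, 4, 'ROI', roi_masks=roi))  # count ROI voxels only
print(class_weights(labels, 3, 'labels'))              # foreground classes 1..3
print(class_weights(labels, 5))                        # class 4 never appears
```

For the first call the counts are [4, 2, 1, 1], so the frequencies are [0.5, 0.25, 0.125, 0.125], the reference is 0.25, and the weights are [0.5, 1, 2, 2]: rare classes are up-weighted and the most frequent class is down-weighted. Note that with an odd number of classes (the `'labels'` call, where three foreground classes remain) index `ceil(3 / 2) = 2` picks the largest sorted frequency rather than the middle one, so the reference sits one position above the true median; `np.median` would be the usual choice if an exact median is wanted.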