def get_class_weights(self, subject_list, mask_bool=True):
    """Compute per-class weights as median class frequency / class frequency.

    Label occurrences are counted over every subject in ``subject_list``,
    normalised to relative frequencies, and each class is weighted by the
    median frequency divided by its own frequency, so rarer classes get
    larger weights.

    Args:
        subject_list: Iterable of objects exposing ``load_labels()`` and,
            for the ``'ROI'`` mode, ``load_ROI_mask()``. Both are expected
            to return array-likes of matching size (assumption: the label
            values are non-negative integer class indices -- confirm
            against the subject implementation).
        mask_bool: Selects which voxels are counted. Despite the name this
            is compared against strings:
            ``'ROI'``    -- each label is weighted by the subject's ROI mask;
            ``'labels'`` -- label 0 is ignored and labels ``1..n_classes``
                            are mapped to weight indices ``0..n_classes-1``;
            anything else (including the default ``True``) -- every label
            is counted.

    Returns:
        A float ``numpy.ndarray`` of length ``self.n_classes``. Classes that
        never occur get weight 0. If nothing was counted at all (empty
        ``subject_list`` or an all-zero mask) every weight is 0.

    Raises:
        ValueError: If a label index falls outside the range allowed by
            ``self.n_classes`` for the selected mode.
    """
    n_classes = self.n_classes
    class_frequencies = np.zeros(n_classes)

    for subj in subject_list:
        labels = np.asarray(subj.load_labels()).ravel().astype('int')
        if mask_bool == 'ROI':
            mask = np.asarray(subj.load_ROI_mask()).ravel().astype('int')
            counts = np.bincount(labels, weights=mask, minlength=n_classes)
        elif mask_bool == 'labels':
            # Dropping bin 0 is equivalent to masking out label 0 first,
            # so no explicit mask is required.
            counts = np.bincount(labels, minlength=n_classes + 1)[1:]
        else:
            counts = np.bincount(labels, minlength=n_classes)

        # bincount grows past minlength when it meets a larger label; report
        # that explicitly instead of failing with a broadcasting error.
        if counts.shape[0] != n_classes:
            raise ValueError(
                'labels contain a class index outside the {} classes '
                'expected (mask_bool={!r})'.format(n_classes, mask_bool))
        class_frequencies += counts

    total = class_frequencies.sum()
    if total == 0:
        # Nothing was counted: avoid 0/0 and return neutral zero weights.
        return np.zeros(n_classes)
    class_frequencies = class_frequencies / total

    # 0-based index of the median (upper median for an even class count).
    median_frequency = np.sort(class_frequencies)[n_classes // 2]

    # Only divide where the class occurs; absent classes keep weight 0
    # instead of becoming infinite.
    class_weight = np.zeros(n_classes)
    present = class_frequencies != 0
    class_weight[present] = median_frequency / class_frequencies[present]
    return class_weight
评论列表
文章目录