def weight_by_class_balance(truth, classes=None):
    """
    Build a per-element loss weight map that balances the given classes.

    Each selected class gets a weight inversely proportional to the fraction
    of elements of ``truth`` that carry that label. The map is then divided
    by the smallest class weight, so the most frequent selected class ends up
    with weight 1 and rarer classes with proportionally larger weights.

    Parameters
    ----------
    truth : array_like
        Label array; converted with ``np.asarray``.
    classes : iterable, optional
        Labels to include in the balancing (for instance to leave out the
        background). Elements whose label is not listed keep weight 0.
        Defaults to every distinct value found in ``truth``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array with the same shape as ``truth``.
    """
    truth = np.asarray(truth)
    if classes is None:
        # Include all classes
        classes = np.unique(truth)

    weight_map = np.zeros_like(truth, dtype=np.float32)
    total_amount = truth.size
    if total_amount == 0:
        # Nothing to weight; also avoids dividing by a zero element count.
        return weight_map

    # Start at +inf so that any real class weight replaces it.
    min_weight = float("inf")
    for c in classes:
        class_mask = truth == c
        # The epsilon keeps the weight finite for a class that does not occur
        # in truth; such a class has an all-False mask and adds nothing.
        class_fraction = (np.count_nonzero(class_mask) + 1e-8) / total_amount
        class_weight = 1.0 / class_fraction
        min_weight = min(min_weight, class_weight)
        weight_map += class_mask * class_weight

    # Normalise so the smallest class weight becomes 1.
    weight_map /= min_weight
    return weight_map
评论列表
文章目录