loss_weighting.py source code


Project: luna16  Author: gzuidhof
```python
import numpy as np


def weight_by_class_balance(truth, classes=None):
    """
    Determines a loss weight map from the ground-truth labels by balancing the classes given in the `classes` argument.
    The `classes` argument can be used to include only certain classes (for instance, you may want to exclude the background).
    """

    if classes is None:
        # Include all classes present in the ground truth
        classes = np.unique(truth)

    weight_map = np.zeros_like(truth, dtype=np.float32)
    # Total number of pixels/voxels (np.product is deprecated and removed in NumPy 2.0)
    total_amount = truth.size

    for c in classes:
        class_mask = np.where(truth == c, 1, 0)
        # Inverse class frequency; the 1e-8 avoids inf/NaN weights when a requested class is absent
        class_weight = 1 / ((np.sum(class_mask) + 1e-8) / total_amount)

        weight_map += class_mask * class_weight  # optionally also divide by total_amount

    return weight_map
```
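The loop above gives every pixel of class c the weight N / n_c, where N is the total number of pixels and n_c is the number of pixels labelled c (up to the 1e-8 epsilon). Every included class therefore contributes the same total weight, n_c · N / n_c = N, to the loss, which is what "balancing" means here. As a hedged illustration, the sketch below computes the same map without a Python loop by using `np.unique(..., return_counts=True)`. The function name `class_balance_weights_vectorized` is hypothetical and is not part of the luna16 project.

```python
import numpy as np


def class_balance_weights_vectorized(truth, classes=None):
    """Hypothetical vectorized variant: weight each pixel by total_pixels / count(its class).

    Pixels whose label is not in `classes` get weight 0, as in the loop version.
    """
    truth = np.asarray(truth)
    # Sorted unique labels and how many pixels carry each label
    labels, counts = np.unique(truth, return_counts=True)
    # Per-class inverse frequency: total / count (every label from np.unique has count >= 1)
    inv_freq = truth.size / counts.astype(np.float64)
    if classes is not None:
        # Zero out the weight of labels that were not requested (e.g. the background)
        inv_freq = np.where(np.isin(labels, classes), inv_freq, 0.0)
    # Map each pixel to the index of its label in `labels`, then look up that label's weight
    idx = np.searchsorted(labels, truth)
    return inv_freq[idx].astype(np.float32)


# Toy 2x4 label map: 6 background pixels (0) and 2 foreground pixels (1)
truth = np.array([[0, 0, 0, 1],
                  [0, 0, 0, 1]])
w = class_balance_weights_vectorized(truth)
# Expected: background pixels get 8/6 (about 1.333), foreground pixels get 8/2 = 4.0
print(w)

# Excluding the background (class 0) leaves it with zero weight
w_fg = class_balance_weights_vectorized(truth, classes=[1])
print(w_fg)
```

Because `np.unique` returns only labels that actually occur, a requested class that is absent simply matches no pixels. This mirrors the original loop, where that class's all-zero mask adds nothing to the weight map.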