metrics.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:iterative_inference_segm 作者: adri-romsor 项目源码 文件源码
def dice_loss(y_pred, y_true, void_class, class_for_dice=1):
    """Return the negative smoothed Dice coefficient for a single class.

    Works for binary problems only: one channel (``class_for_dice``) of the
    prediction is compared against the same channel of the target.

    Args:
        y_pred: softmax output, indexed here as ``[:, class, :, :]``
            (presumably batch, class, height, width -- confirm with caller).
        y_true: one-hot target, indexed the same way as ``y_pred``.
        void_class: sequence of void values; positions where the flattened
            target equals one of them are dropped from both tensors before
            scoring. May be empty or ``None`` (nothing is dropped).
        class_for_dice: index of the channel to score.

    Returns:
        Scalar expression
        ``-(2 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
        The sign is negated so that minimizing the loss maximizes Dice.
    """
    # Additive smoothing keeps the ratio defined when both sums are zero.
    smooth = 1

    y_true_f = T.flatten(y_true[:, class_for_dice, :, :])
    y_true_f = T.cast(y_true_f, 'int32')
    y_pred_f = T.flatten(y_pred[:, class_for_dice, :, :])

    # Remove void positions from the dice computation.
    # NOTE(review): y_true_f holds the values of a single one-hot channel,
    # yet it is compared against void_class entries. If void_class contains
    # class *indices*, this comparison may not mask void pixels as intended
    # -- confirm against the caller.
    if void_class is not None:
        for void_value in void_class:
            # Indices of the entries that are not void; keep only those in
            # both the target and the prediction so they stay aligned.
            keep = T.neq(y_true_f, void_value).nonzero()
            y_pred_f = y_pred_f[keep]
            y_true_f = y_true_f[keep]

    intersection = T.sum(y_true_f * y_pred_f)
    denominator = T.sum(y_true_f) + T.sum(y_pred_f) + smooth
    return -(2. * intersection + smooth) / denominator
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号