ash.py source code

python

Project: CNNbasedMedicalSegmentation  Author: BRML
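import theano.tensor as T


def tanimoto_wmap(target_in, prediction, eps=1e-8):
    '''
    Weighted Tanimoto distance, see: https://en.wikipedia.org/wiki/Jaccard_index#Other_definitions_of_Tanimoto_distance

    Judging by the indexing below, target_in carries two target channels plus
    a weight map along its last axis, and both inputs have a leading axis of size 1.
    '''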
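    # Drop the leading axis: (1, n, c) -> (n, c).
    target_in = T.reshape(target_in, (target_in.shape[1], target_in.shape[2]))
    # The first two columns are the targets; the third column is the weight map.
    target = target_in[:, :2]
    # Tile the weight column to shape (n, 2) so it lines up with the targets.
    wmap = T.repeat(target_in[:, 2].dimshuffle(('x', 0)), 2, axis=0).dimshuffle((1, 0))
    prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2]))
    # Keep predictions within [eps, 1 - eps].
    prediction = T.clip(prediction, eps, 1 - eps)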

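    # Weighted squared norms per channel, each of shape (1, 2).
    target_w = T.sum(T.sqr(target * wmap), axis=0, keepdims=True)
    pred_w = T.sum(T.sqr(prediction * wmap), axis=0, keepdims=True)
    # Note: this is the product of the two weighted squared norms, not the
    # weighted dot product T.sum((target * wmap) * (prediction * wmap), axis=0)
    # that the Tanimoto definition suggests. Left unchanged from the source.
    intersection_w = T.sum(target_w * pred_w, axis=0, keepdims=True)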

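    # Unweighted dot product and squared norms per channel, each of shape (1, 2).
    intersection = T.sum(target * prediction, axis=0, keepdims=True)
    prediction_sq = T.sum(T.sqr(prediction), axis=0, keepdims=True)
    target_sq = T.sum(T.sqr(target), axis=0, keepdims=True)

    # Weighted numerator over unweighted denominator; one loss value per channel.
    loss = (target_w + pred_w - 2 * intersection_w) / (target_sq + prediction_sq - intersection)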
    return loss
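
To sanity-check the formula that the docstring links to, here is a minimal pure-Python sketch of the plain (unweighted) Tanimoto distance, (|a|² + |b|² - 2a·b) / (|a|² + |b|² - a·b). The name `tanimoto_distance` and the handling of a zero denominator are assumptions made for illustration; neither comes from the project, and the sketch leaves out the weight map entirely.

```python
def tanimoto_distance(target, prediction):
    """Unweighted Tanimoto distance between two equal-length sequences of floats."""
    if len(target) != len(prediction):
        raise ValueError("target and prediction must have the same length")
    dot = sum(t * p for t, p in zip(target, prediction))
    target_sq = sum(t * t for t in target)
    prediction_sq = sum(p * p for p in prediction)
    denominator = target_sq + prediction_sq - dot
    if denominator == 0:
        # Both vectors are all zeros; treating them as identical is an assumption.
        return 0.0
    return (target_sq + prediction_sq - 2 * dot) / denominator


mask = [1.0, 0.0, 1.0, 1.0]
identical = tanimoto_distance(mask, mask)                  # same mask twice
disjoint = tanimoto_distance(mask, [0.0, 1.0, 0.0, 0.0])   # no overlap
partial = tanimoto_distance(mask, [1.0, 0.0, 0.0, 0.0])    # one of three pixels hit
print(identical, disjoint, partial)
```

For binary masks this reduces to the Jaccard distance: identical masks give 0, disjoint masks give 1, and the partial case gives 1 - 1/3 = 2/3.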