unet.py source code

python

Project: luna16 Author: gzuidhof
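import theano
import theano.tensor as T
import lasagne


def score_metrics(out, target_var, weight_map, l2_loss=0):
    # Clipping bound that keeps log() away from 0 in the cross-entropy.
    _EPSILON = 1e-8

    # (batch, class, h, w) -> (class, batch, h, w) -> (class, N) -> (N, class)
    out_flat = out.dimshuffle(1, 0, 2, 3).flatten(ndim=2).dimshuffle(1, 0)
    # Targets and weights are flattened in the same pixel order as out_flat.
    target_flat = target_var.dimshuffle(1, 0, 2, 3).flatten(ndim=1)
    weight_flat = weight_map.dimshuffle(1, 0, 2, 3).flatten(ndim=1)

    # Per-pixel class probabilities and the hard (argmax) class labels.
    prediction = lasagne.nonlinearities.softmax(out_flat)
    prediction_binary = T.argmax(prediction, axis=1)

    # Dice score for binary labels: prediction + target equals 2 only where
    # both are 1. The value is undefined (0/0) when neither contains a 1.
    dice_score = (T.sum(T.eq(2, prediction_binary + target_flat)) * 2.0 /
                  (T.sum(prediction_binary) + T.sum(target_flat)))

    # Per-pixel cross-entropy, scaled by the weight map, averaged over all
    # pixels, plus the optional L2 regularization term.
    loss = lasagne.objectives.categorical_crossentropy(
        T.clip(prediction, _EPSILON, 1 - _EPSILON), target_flat)
    loss = loss * weight_flat
    loss = loss.mean()
    loss += l2_loss

    # Fraction of pixels whose predicted label matches the target.
    accuracy = T.mean(T.eq(prediction_binary, target_flat),
                      dtype=theano.config.floatX)

    return loss, accuracy, dice_score, target_flat, prediction, prediction_binary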
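
The symbolic graph above is hard to check by hand, so the following is a minimal NumPy sketch of the same arithmetic on concrete arrays. It is not part of the luna16 project: the name `score_metrics_np` and the tiny 2x2 example are hypothetical, and it assumes `out` has shape (batch, class, h, w) while the integer targets and the weight map have shape (batch, 1, h, w).

```python
import numpy as np


def score_metrics_np(out, target, weight_map, eps=1e-8):
    """Hypothetical NumPy mirror of score_metrics (without the L2 term)."""
    n_classes = out.shape[1]
    # (batch, class, h, w) -> (class, batch, h, w) -> (class, N) -> (N, class)
    out_flat = out.transpose(1, 0, 2, 3).reshape(n_classes, -1).T
    target_flat = target.transpose(1, 0, 2, 3).reshape(-1)
    weight_flat = weight_map.transpose(1, 0, 2, 3).reshape(-1)

    # Row-wise softmax, shifted by the row maximum for numerical stability.
    shifted = out_flat - out_flat.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    prediction = exp / exp.sum(axis=1, keepdims=True)
    prediction_binary = prediction.argmax(axis=1)

    # prediction + target equals 2 only at true-positive pixels.
    true_pos = np.sum((prediction_binary + target_flat) == 2)
    dice = 2.0 * true_pos / (prediction_binary.sum() + target_flat.sum())

    # Weighted cross-entropy: -log of the probability of the true class.
    clipped = np.clip(prediction, eps, 1 - eps)
    ce = -np.log(clipped[np.arange(target_flat.size), target_flat])
    loss = np.mean(ce * weight_flat)

    accuracy = np.mean(prediction_binary == target_flat)
    return loss, accuracy, dice


# One 2x2 image, two classes; the argmax labels come out as [0, 1, 1, 0].
out = np.array([[[[2.0, 0.0], [0.0, 2.0]],
                 [[0.0, 2.0], [2.0, 0.0]]]])
target = np.array([[[[0, 1], [0, 0]]]])
weights = np.ones((1, 1, 2, 2))

loss, accuracy, dice = score_metrics_np(out, target, weights)
print(loss, accuracy, dice)
```

In this example one of the two predicted foreground pixels overlaps the single target pixel, so the Dice score is 2 * 1 / (2 + 1), and three of the four pixels are labelled correctly. As in the Theano version, the Dice score is undefined when neither the prediction nor the target contains a foreground pixel.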