def tanimoto_wmap(target_in, prediction, eps=1e-8):
    """Weighted Tanimoto distance loss for a single sample.

    See: https://en.wikipedia.org/wiki/Jaccard_index#Other_definitions_of_Tanimoto_distance

    For vectors A and B the Tanimoto distance is
    (|A|^2 + |B|^2 - 2 A.B) / (|A|^2 + |B|^2 - A.B), i.e. one minus the
    Tanimoto coefficient.  In this variant the numerator is evaluated on the
    weight-scaled vectors (so it equals sum(w^2 * (target - prediction)^2)
    per channel), while the denominator is evaluated on the raw vectors.

    Args:
        target_in: symbolic tensor. Its leading axis is dropped by the
            reshape below, so that axis must have size 1 (presumably shape
            (1, N, C) with C >= 3 -- TODO confirm with caller). Columns 0-1
            of the reshaped tensor are the target, column 2 is the
            per-element weight map.
        prediction: symbolic tensor, reshaped the same way; afterwards it
            must match the shape of the two target columns (presumably
            (1, N, 2) before the reshape).
        eps: predictions are clipped to [eps, 1 - eps].

    Returns:
        Symbolic tensor of shape (1, 2): the loss for each target column.
    """
    target_in = T.reshape(target_in, (target_in.shape[1], target_in.shape[2]))
    target = target_in[:, :2]
    # (N,) -> (N, 1) with a broadcastable second axis, so the weight map
    # scales every column without building a repeated copy.
    wmap = target_in[:, 2].dimshuffle(0, 'x')

    prediction = T.reshape(prediction, (prediction.shape[1], prediction.shape[2]))
    # Clipping keeps prediction >= eps, which also keeps every element of the
    # denominator term t^2 + p^2 - t*p strictly positive.
    prediction = T.clip(prediction, eps, 1 - eps)

    target_wt = target * wmap
    prediction_wt = prediction * wmap

    # Weighted terms |wA|^2, |wB|^2 and (wA).(wB), summed over elements.
    target_w = T.sum(T.sqr(target_wt), axis=0, keepdims=True)
    pred_w = T.sum(T.sqr(prediction_wt), axis=0, keepdims=True)
    # Dot product of the weighted vectors (not the product of the two squared
    # norms), so the numerator below is sum(w^2 * (t - p)^2) >= 0.
    intersection_w = T.sum(target_wt * prediction_wt, axis=0, keepdims=True)

    # Unweighted terms for the denominator.
    intersection = T.sum(target * prediction, axis=0, keepdims=True)
    prediction_sq = T.sum(T.sqr(prediction), axis=0, keepdims=True)
    target_sq = T.sum(T.sqr(target), axis=0, keepdims=True)

    # NOTE(review): the numerator is weighted but the denominator is not;
    # kept as in the original -- confirm this mix is intended.
    loss = (target_w + pred_w - 2 * intersection_w) / (target_sq + prediction_sq - intersection)
    return loss
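# Web-page residue from the page this code was copied from:
# "评论列表" (comment list), "文章目录" (table of contents).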