def dice_loss(y_pred, y_true, void_class, class_for_dice=1):
    '''
    Negative soft Dice coefficient for a single class -- binary use only.

    y_pred is a softmax output and y_true is one hot. Both are indexed as
    ``tensor[:, class, :, :]``, i.e. 4-D with the class axis at axis 1.

    Parameters
    ----------
    y_pred : symbolic 4-D tensor
        Softmax probabilities, class axis 1.
    y_true : symbolic 4-D tensor
        One-hot ground truth, class axis 1.
    void_class : sequence of int, or None
        Class ids whose pixels are excluded from the score. A pixel's class
        id is taken as the argmax of y_true along axis 1 (note: an all-zero
        one-hot row therefore maps to class 0). ``None`` or an empty
        sequence disables masking.
    class_for_dice : int, optional
        Channel of y_pred / y_true on which Dice is computed (default 1).

    Returns
    -------
    Symbolic scalar
        ``-(2 * intersection + smooth) / (sum(true) + sum(pred) + smooth)``
        with ``smooth = 1``. It is negated so that minimising the loss
        maximises Dice overlap.
    '''
    # Additive smoothing keeps the ratio defined when both sums are zero
    # (e.g. every pixel was masked out as void).
    smooth = 1
    y_true_f = T.flatten(y_true[:, class_for_dice, :, :])
    y_true_f = T.cast(y_true_f, 'int32')
    y_pred_f = T.flatten(y_pred[:, class_for_dice, :, :])

    # Remove void pixels from both tensors. Void pixels are identified by
    # their class id (argmax over the one-hot axis), not by the 0/1 value
    # of the class_for_dice channel, which can never equal a class id >= 2.
    void_class = [] if void_class is None else list(void_class)
    if void_class:
        # Per-pixel class id over axes (0, 2, 3) of y_true, flattened in the
        # same order as y_true_f / y_pred_f above.
        labels = T.flatten(T.argmax(y_true, axis=1))
        keep = T.neq(labels, void_class[0])
        for void_id in void_class[1:]:
            keep = T.and_(keep, T.neq(labels, void_id))
        # Index once with the combined mask so both tensors stay aligned.
        idxs = keep.nonzero()
        y_pred_f = y_pred_f[idxs]
        y_true_f = y_true_f[idxs]

    intersection = T.sum(y_true_f * y_pred_f)
    return -(2. * intersection + smooth) / (
        T.sum(y_true_f) + T.sum(y_pred_f) + smooth)
# NOTE: scraped web-page residue ("Comment list" / "Table of contents"),
# not part of the code.