def jaccard(y_pred, y_true, n_classes, one_hot=False):
    """Build symbolic per-class Jaccard (IoU) numerators and denominators.

    The ratio itself is *not* taken here: intersection and union counts are
    returned separately, presumably so callers can accumulate them over
    several batches before dividing (verify against callers).

    Parameters
    ----------
    y_pred : symbolic tensor, 1-D or 2-D
        Predicted class indices (1-D), or a 2-D tensor that is reduced to
        indices with ``argmax`` over axis 1 (axis 1 is treated as the class
        axis).
    y_true : symbolic tensor
        Ground-truth class indices, or, when ``one_hot`` is true, a tensor
        reduced to indices with ``argmax`` over axis 1. Presumably the same
        length as the reduced ``y_pred`` -- TODO confirm.
    n_classes : int
        Number of classes. Must be a Python int because it drives the
        graph-building ``range`` loops. Label values outside
        ``[0, n_classes)`` never match any class index and so are not
        counted in any cell.
    one_hot : bool, optional
        Whether ``y_true`` must be reduced with ``argmax`` first.

    Returns
    -------
    symbolic tensor of shape (2, n_classes)
        Row 0 is the per-class intersection ``TP``; row 1 is the per-class
        union ``TP + FP + FN``.

    Raises
    ------
    ValueError
        If ``y_pred`` is neither 1-D nor 2-D.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if y_pred.ndim not in (1, 2):
        raise ValueError(
            "jaccard: y_pred must be 1-D or 2-D, got ndim=%d" % y_pred.ndim)

    # Reduce 2-D inputs to class indices.
    if y_pred.ndim == 2:
        y_pred = T.argmax(y_pred, axis=1)
    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # Confusion matrix: cm[i, j] = number of elements predicted as class i
    # whose ground-truth label is class j.
    # NOTE(review): this unrolls n_classes**2 set_subtensor nodes into the
    # graph, so graph size grows quadratically with the number of classes.
    cm = T.zeros((n_classes, n_classes))
    for i in range(n_classes):
        for j in range(n_classes):
            cm = T.set_subtensor(
                cm[i, j], T.sum(T.eq(y_pred, i) * T.eq(y_true, j)))

    # Rows index predictions and columns index ground truth, so a row sum is
    # "predicted as i" (TP + FP) and a column sum is "labelled i" (TP + FN).
    TP_perclass = T.cast(cm.diagonal(), _FLOATX)
    FP_perclass = cm.sum(1) - TP_perclass
    FN_perclass = cm.sum(0) - TP_perclass

    num = TP_perclass                                # intersection
    denom = TP_perclass + FP_perclass + FN_perclass  # union
    return T.stack([num, denom], axis=0)
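# (Scraped web-page residue, not code: "Comment list")
# (Scraped web-page residue, not code: "Table of contents")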