def accuracy(y_pred, y_true, void_labels, one_hot=False):
    """Build the symbolic mean accuracy of `y_pred` against `y_true`.

    Elements whose ground-truth label is listed in `void_labels` are
    excluded from both the count of correct predictions and the
    normalisation.

    Parameters
    ----------
    y_pred : symbolic tensor, 1D or 2D
        A 2D input is reduced with ``argmax`` along axis 1 (presumably
        per-class scores laid out as (samples, classes) -- confirm
        against the caller); a 1D input is compared as-is.
    y_true : symbolic tensor
        Ground-truth labels. Reduced with ``argmax`` along axis 1 when
        `one_hot` is true.
    void_labels : iterable
        Label values to ignore. An empty iterable disables masking.
    one_hot : bool, optional
        Whether `y_true` must be converted to indices first.

    Returns
    -------
    Symbolic scalar: the sum of masked matches divided by the sum of
    the mask. The division is not guarded, so if every element is void
    the denominator is zero.
    """
    assert (y_pred.ndim == 2) or (y_pred.ndim == 1), (
        'y_pred must be 1D or 2D, got ndim=%d' % y_pred.ndim)

    # Bring predictions (and, if requested, targets) down to indices.
    if y_pred.ndim == 2:
        y_pred = T.argmax(y_pred, axis=1)
    if one_hot:
        y_true = T.argmax(y_true, axis=1)

    # 1. where the prediction matches the target, 0. elsewhere.
    correct = T.eq(y_pred, y_true).astype(_FLOATX)

    # Zero out every position holding a void label. This is done with an
    # elementwise comparison rather than nonzero()/set_subtensor: a Python
    # `if` on the symbolic index tuple is evaluated while the graph is
    # built and cannot depend on the run-time data, so it guarded nothing.
    mask = T.ones_like(y_true, dtype=_FLOATX)
    for el in void_labels:
        mask = mask * T.neq(y_true, el).astype(_FLOATX)

    return T.sum(correct * mask) / T.sum(mask)
评论列表
文章目录