def scores(output, target, threshold=0.5):
    """Compute sample-averaged accuracy, precision, recall and F2 score.

    ``output`` is binarised with ``output > threshold`` and ``target`` is cast
    to ``long``. Each pair of entries along the first (batch) dimension is one
    sample. For every sample the confusion counts (TP, TN, FP, FN) are taken
    over all of its elements, the four metrics are computed, and the
    per-sample values are averaged over the batch.

    The counting formulas assume 0/1 targets (TODO confirm with callers).
    For other label values the confusion cells can overlap.

    Conventions for degenerate samples:
      * TP = FP = FN = 0 (nothing predicted, nothing expected):
        precision = recall = F2 = 1.0;
      * a sample with zero elements: accuracy = 1.0 as well;
      * TP = 0 but FP > 0 or FN > 0: precision = recall = F2 = 0.0.

    Args:
        output: Tensor of scores; the first dimension is the batch.
        target: Tensor of ground-truth labels with the same batch size.
        threshold: Scores strictly greater than this count as positive.

    Returns:
        Tuple ``(accuracy, precision, recall, fmeasure)`` of Python floats,
        where ``fmeasure`` is the F-beta score with beta = 2.

    Raises:
        ValueError: if the batch is empty or ``output`` and ``target`` have
            different batch sizes.
    """
    outputr = (output > threshold).long()
    target = target.long()
    if len(outputr) != len(target):
        # zip() below would otherwise silently drop the extra samples.
        raise ValueError(
            f"output and target batch sizes differ: "
            f"{len(outputr)} != {len(target)}"
        )
    count = len(outputr)
    if count == 0:
        raise ValueError("scores() requires at least one sample")

    def _count_nonzero(x):
        # Number of non-zero elements. This matches torch.nonzero(x).size(0)
        # without materialising the index tensor.
        return int((x != 0).sum())

    a_sum = p_sum = r_sum = f2_sum = 0.0
    for o, t in zip(outputr, target):
        # For 0/1 tensors each product is non-zero exactly where the
        # (prediction, label) pair falls in that confusion cell.
        tp = _count_nonzero(o * t)
        tn = _count_nonzero((o - 1) * (t - 1))
        fp = _count_nonzero(o * (t - 1))
        fn = _count_nonzero((o - 1) * t)
        total = tp + fp + fn + tn
        # An empty sample has nothing to get wrong. Treat it as perfect,
        # consistent with the precision/recall/F2 convention below.
        a = (tp + tn) / total if total else 1.0
        if tp == 0 and fp == 0 and fn == 0:
            p = r = f2 = 1.0
        elif tp == 0:
            # Here fp > 0 or fn > 0 is implied by the failed branch above.
            p = r = f2 = 0.0
        else:
            p = tp / (tp + fp)
            r = tp / (tp + fn)
            # F-beta with beta = 2: (1 + b^2) * p * r / (b^2 * p + r).
            f2 = (5 * p * r) / (4 * p + r)
        a_sum += a
        p_sum += p
        r_sum += r
        f2_sum += f2

    return a_sum / count, p_sum / count, r_sum / count, f2_sum / count
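# Navigation text left over from the web page this snippet was copied from:
# "评论列表" ("Comment list") and "文章目录" ("Table of contents"); not part of the code.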