def testClassErrorMeter(self):
    """Check ClassErrorMeter's top-1 error over two successive batches.

    ``torch.eye(3)`` scores are highest on the diagonal, so the predicted
    class for row ``i`` is ``i``.  The first batch uses targets ``[0, 1, 2]``,
    so every prediction is correct.  The second uses ``[1, 0, 0]``, so every
    prediction is wrong.  The asserted 50% after the second batch therefore
    relies on the meter accumulating counts across ``add()`` calls
    (3 wrong out of 6 samples seen).  ``value()`` is expected to return the
    error as a percentage in a list (one entry here, matching
    ``topk=[1]``).
    """
    mtr = meter.ClassErrorMeter(topk=[1])
    output = torch.eye(3)

    # Older torch builds lack ``arange``; ``torch.range`` is end-inclusive,
    # so (0, 2) yields the same three labels 0, 1, 2 (as floats).
    if hasattr(torch, "arange"):
        target = torch.arange(0, 3)
    else:
        target = torch.range(0, 2)
    mtr.add(output, target)
    err = mtr.value()
    self.assertEqual(err, [0], "All should be correct")

    # Build a fresh tensor for the second batch instead of overwriting the
    # one already handed to the meter, so the test does not depend on the
    # meter having finished with (or copied) the earlier tensor.
    target = torch.LongTensor([1, 0, 0])
    mtr.add(output, target)
    err = mtr.value()
    self.assertEqual(err, [50.0], "Half should be correct")
评论列表
文章目录