test_meters.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:tnt 作者: pytorch 项目源码 文件源码
def testClassErrorMeter(self):
        mtr = meter.ClassErrorMeter(topk=[1])
        output = torch.eye(3)
        if hasattr(torch, "arange"):
            target = torch.arange(0, 3)
        else:
            target = torch.range(0, 2)
        mtr.add(output, target)
        err = mtr.value()

        self.assertEqual(err, [0], "All should be correct")

        target[0] = 1
        target[1] = 0
        target[2] = 0
        mtr.add(output, target)
        err = mtr.value()
        self.assertEqual(err, [50.0], "Half should be correct")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号