meter.py source code

python

Project: python-utils  Author: zhijian-liu
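import numpy as np
import torch

# `to_numpy` is not defined in this excerpt; it is presumably a helper from elsewhere in the project.
def add(self, outputs, targets):
        # Convert both inputs to NumPy arrays.
        outputs = to_numpy(outputs)
        targets = to_numpy(targets)

        # Collapse 2D targets (e.g. one-hot rows) to class indices.
        if np.ndim(targets) == 2:
            targets = np.argmax(targets, 1)

        assert np.ndim(outputs) == 2, 'wrong output size (2D expected)'
        assert np.ndim(targets) == 1, 'wrong target size (1D or 2D expected)'
        assert targets.shape[0] == outputs.shape[0], 'number of outputs and targets do not match'

        top_k = self.top_k
        # The last entry is taken as the largest k, so top_k is expected in ascending order.
        max_k = int(top_k[-1])

        # Indices of the max_k highest-scoring classes per sample, best first.
        predict = torch.from_numpy(outputs).topk(max_k, 1, True, True)[1].numpy()
        # correct[i, j] is True when the j-th ranked prediction for sample i equals its target.
        correct = predict == targets[:, np.newaxis]

        # Accumulate the sample count and the number of top-k hits for each k.
        self.size += targets.shape[0]
        for k in top_k:
            self.corrects[k] += correct[:, :k].sum()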
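
The excerpt shows only the `add` method, so the surrounding class is not visible. As a minimal sketch of how such a top-k accuracy meter could fit together, the example below wraps the same counting logic in a hypothetical `TopKMeter` class. The class name, the constructor, and the `value()` method are assumptions for illustration and are not taken from python-utils; it also uses `np.argsort` in place of `torch.topk` so that it runs with NumPy alone.

```python
import numpy as np


class TopKMeter:
    """Hypothetical stand-in for the class that owns `add` (not the project's actual class)."""

    def __init__(self, top_k=(1, 5)):
        # Keep k values in ascending order so the last one is the largest.
        self.top_k = sorted(top_k)
        self.size = 0
        self.corrects = {k: 0 for k in self.top_k}

    def add(self, outputs, targets):
        outputs = np.asarray(outputs)
        targets = np.asarray(targets)
        # Collapse one-hot style 2D targets to class indices.
        if targets.ndim == 2:
            targets = targets.argmax(axis=1)

        max_k = int(self.top_k[-1])
        # Indices of the max_k largest scores per row, best first
        # (swapped in for torch.topk so that only NumPy is needed).
        predict = np.argsort(-outputs, axis=1)[:, :max_k]
        correct = predict == targets[:, np.newaxis]

        self.size += targets.shape[0]
        for k in self.top_k:
            self.corrects[k] += int(correct[:, :k].sum())

    def value(self):
        # Assumed accessor: fraction of samples whose target is in the top k.
        return {k: self.corrects[k] / self.size for k in self.top_k}


meter = TopKMeter(top_k=(1, 2))
scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 0])
meter.add(scores, labels)
accuracy = meter.value()
```

With these three samples, the first is a top-1 hit, the second is only a top-2 hit, and the third misses both, so `accuracy` holds one third for k=1 and two thirds for k=2.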