def add(self, outputs, targets):
    """Accumulate top-k hit counts for one batch of predictions.

    Args:
        outputs: per-class scores; after ``to_numpy`` they must be 2-D with
            shape (N, C), and the top-k is taken along axis 1.
        targets: either 1-D class indices of length N, or a 2-D array
            (presumably one-hot / per-class scores -- confirm with callers)
            that is reduced to indices with ``argmax`` along axis 1.

    Side effects:
        ``self.size`` grows by N.  For every ``k`` in ``self.top_k``,
        ``self.corrects[k]`` grows by the number of rows whose target index
        is among the ``k`` highest-scoring classes.

    Raises:
        AssertionError: if ``outputs``/``targets`` have the wrong number of
            dimensions or their leading dimensions differ.
    """
    outputs = to_numpy(outputs)
    targets = to_numpy(targets)
    if np.ndim(targets) == 2:
        targets = np.argmax(targets, 1)
    assert np.ndim(outputs) == 2, 'wrong output size (2D expected)'
    assert np.ndim(targets) == 1, 'wrong target size (1D or 2D expected)'
    assert targets.shape[0] == outputs.shape[0], 'number of outputs and targets do not match'

    top_k = self.top_k
    # Use max() rather than top_k[-1] so an unsorted top_k still requests
    # enough columns.  Clamp to the number of classes because torch.topk
    # raises when k exceeds the size of the dimension.  Any k >= C then
    # simply counts every row whose target is a valid class.
    max_k = min(int(max(top_k)), outputs.shape[1])
    # torch.from_numpy rejects negatively-strided arrays (e.g. reversed
    # views).  ascontiguousarray copies only when that is actually needed.
    scores = torch.from_numpy(np.ascontiguousarray(outputs))
    # Indices of the max_k largest scores per row, sorted by descending score.
    predict = scores.topk(max_k, 1, True, True)[1].numpy()
    # Broadcast (N, max_k) against (N, 1) instead of materialising a
    # repeated copy of targets.
    correct = predict == targets[:, np.newaxis]

    self.size += targets.shape[0]
    for k in top_k:
        # Columns are score-sorted, so the first k columns are the top-k.
        self.corrects[k] += correct[:, :k].sum()
评论列表
文章目录