def metric(self, predictions, targets, num_classes=5):
    """
    Compute the AUROC metric for a set of predictions.

    Args:
        predictions: tensor/array of network outputs. A 2D input is passed
            through unchanged (presumably per-class scores, one column per
            class -- TODO confirm against ``_auroc``). A 1D input (hard class
            predictions) is first converted with
            ``one_hot(predictions, m=num_classes)``.
        targets: ground-truth labels. A 1D input is used as-is. A 2D input
            with more than one column (e.g. one-hot rows) is reduced to
            class indices with ``np.argmax(..., axis=1)``. A 2D input with
            exactly one column is treated as a column vector of labels and
            flattened.
        num_classes: int, number of classes; only used when converting 1D
            predictions via ``one_hot``.

    Returns:
        The result of ``self._auroc(predictions, targets)`` (the AUROC
        score).
    """
    if targets.ndim == 2:
        if targets.shape[1] == 1:
            # argmax over a single column is always 0, which would silently
            # replace every label with class 0; flatten the column instead.
            targets = targets.reshape(-1)
        else:
            # One column per class -> index of the highest-valued column.
            targets = np.argmax(targets, axis=1)
    if predictions.ndim == 1:
        # Hard label predictions are expanded to num_classes columns so
        # _auroc receives a 2D input, like the score-prediction path.
        predictions = one_hot(predictions, m=num_classes)
    return self._auroc(predictions, targets)
评论列表
文章目录