def get_confusion_matrix(prediction, truth):
    """Build the confusion matrix for a classifier's predictions.

    Both inputs hold class labels, not one-hot vectors. An input whose
    ``shape`` has exactly two entries is reduced to its first column
    (``[:, 0]``) before the comparison; any other input is passed through
    unchanged.

    Args:
        prediction: class labels predicted by the network.
        truth: ground-truth class labels.

    Returns:
        The result of ``confusion_matrix(y_true=truth, y_pred=prediction)``.
    """
    def _first_column_if_2d(labels):
        # Two-dimensional label arrays are collapsed to their first column.
        return labels[:, 0] if len(labels.shape) == 2 else labels

    y_pred = _first_column_if_2d(prediction)
    y_true = _first_column_if_2d(truth)
    return confusion_matrix(y_true=y_true, y_pred=y_pred)
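# Comment list (web-page navigation text, not part of the code)
# Article table of contents (web-page navigation text, not part of the code)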