utils.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Vulcan 作者: rfratila 项目源码 文件源码
def get_confusion_matrix(prediction, truth):
    """
    Calculate the confusion matrix for classification network predictions.

    Args:
        predicted: the class matrix predicted by the network.
                   Does not take one hot vectors.
        actual: the class matrix of the ground truth
                Does not take one hot vectors.

    Returns: the confusion matrix
    """
    if len(prediction.shape) == 2:
        prediction = prediction[:, 0]
    if len(truth.shape) == 2:
        truth = truth[:, 0]

    return confusion_matrix(y_true=truth,
                            y_pred=prediction)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号