metrics.py source code


Project: speech_ml  Author: coopie
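```python
import numpy as np
from sklearn.metrics import confusion_matrix


def confusion_matrix_metric(targets, predictions, threshold=0.5):
    """
    Compute a confusion matrix.

    Works for an arbitrary number of classes. If the data has a single
    column (shape (n_samples, 1)), treat it as binary classification with
    `threshold` as the cutoff point; otherwise treat each row as one-hot
    targets / per-class scores and use the argmax as the class label.
    """
    # Both inputs must be 2-D arrays of shape (n_samples, n_columns).
    assert targets.ndim == predictions.ndim == 2
    assert targets.shape == predictions.shape

    if targets.shape[1] == 1:
        # Binary case: threshold the scores into 0/1 labels.
        labels = [0, 1]
        targets = (targets > threshold).astype(int)
        predictions = (predictions > threshold).astype(int)
    else:
        # Multi-class case: the highest-scoring column is the class label.
        labels = list(range(targets.shape[1]))
        targets = np.argmax(targets, axis=1)
        predictions = np.argmax(predictions, axis=1)

    targets = targets.flatten()
    predictions = predictions.flatten()

    # Passing `labels` keeps the matrix n_classes x n_classes even when
    # some class is absent from this batch.
    conf_matrix = confusion_matrix(targets, predictions, labels=labels)
    # Metric values and their names are returned as parallel lists.
    return [conf_matrix], ['confusion_matrix']
```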
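To show what the counting step does without depending on scikit-learn, here is a minimal NumPy-only sketch of the same logic. The name `confusion_matrix_numpy` is hypothetical and not part of speech_ml; it assumes the same 2-D input shapes as `confusion_matrix_metric` and follows scikit-learn's convention that rows are true classes and columns are predicted classes. Each (true, predicted) pair is encoded as a single index `true * n_classes + pred`, counted with `np.bincount`, and reshaped into an n_classes x n_classes matrix.

```python
import numpy as np


def confusion_matrix_numpy(targets, predictions, threshold=0.5):
    """Hypothetical NumPy-only sketch of confusion_matrix_metric.

    Rows index the true class, columns the predicted class.
    """
    targets = np.asarray(targets)
    predictions = np.asarray(predictions)
    assert targets.ndim == predictions.ndim == 2
    assert targets.shape == predictions.shape

    if targets.shape[1] == 1:
        # Binary case: threshold the single column into 0/1 labels.
        n_classes = 2
        y_true = (targets[:, 0] > threshold).astype(np.int64)
        y_pred = (predictions[:, 0] > threshold).astype(np.int64)
    else:
        # Multi-class case: the highest-scoring column is the label.
        n_classes = targets.shape[1]
        y_true = np.argmax(targets, axis=1)
        y_pred = np.argmax(predictions, axis=1)

    # Encode each (true, predicted) pair as one flat index and count them.
    # minlength guarantees a full n_classes x n_classes result.
    counts = np.bincount(y_true * n_classes + y_pred, minlength=n_classes ** 2)
    return counts.reshape(n_classes, n_classes)


# Binary usage: true labels [1, 0, 1, 1], thresholded predictions [1, 1, 0, 1].
binary = confusion_matrix_numpy(
    np.array([[1.0], [0.0], [1.0], [1.0]]),
    np.array([[0.9], [0.6], [0.3], [0.8]]),
)
print(binary.tolist())  # [[0, 1], [1, 2]]

# Multi-class usage: one-hot targets [0, 1, 2, 1], argmax predictions [0, 2, 2, 1].
multi = confusion_matrix_numpy(
    np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]),
    np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.2, 0.6], [0.1, 0.8, 0.1]]),
)
print(multi.tolist())  # [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
```

Using `minlength` plays the same role as passing `labels` to scikit-learn: the matrix keeps a fixed shape even when a class never appears in a batch, which makes results from different batches easy to sum.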