sklearn.py source code

python

Project: tu-dortmund-ice-cube  Author: wjam1995
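```python
def threshold_weighted_confusion_matrix(y_true, y_pred, weights, th=0.5):
    """
    Computes a weighted confusion matrix with a threshold on predictions.

    Takes numpy arrays.

    Arguments:
        y_true - one-hot encoded labels, shape (n_samples, n_classes)
        y_pred - predicted class probabilities, shape (n_samples, n_classes);
                 column 1 is the signal class
        weights - weight for each waveform (one per sample)
        th - probability threshold: a sample is predicted as signal when
             its signal-class probability exceeds th (default: 0.5)

    Returns:
        confusion_matrix - a numpy array containing the confusion matrix

    """
    # Flatten one-hot labels to class indices (assumes exactly one 1 per row)
    true_labels = y_true.nonzero()[1]
    # Threshold the signal-class probability and cast the booleans to 0/1 labels
    pred_labels = (y_pred[:, 1] > th).astype(int)
    # weighted_confusion_matrix is defined outside this excerpt (not shown)
    return weighted_confusion_matrix(true_labels, pred_labels, weights)
```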
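The excerpt does not include `weighted_confusion_matrix`. The following is a minimal sketch, assuming NumPy and a binary (background/signal) problem, of what such a helper could look like and how the thresholding step feeds it. `weighted_confusion_matrix` and its `n_classes` parameter here are hypothetical stand-ins, not the project's actual implementation; rows index the true class and columns the predicted class.

```python
import numpy as np


def weighted_confusion_matrix(y_true, y_pred, weights, n_classes=2):
    """Hypothetical stand-in: sum the weight of each (true, predicted) pair.

    Rows index the true class, columns the predicted class.
    """
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    weights = np.asarray(weights, dtype=float)
    cm = np.zeros((n_classes, n_classes), dtype=float)
    # np.add.at accumulates correctly even when (row, col) pairs repeat
    np.add.at(cm, (y_true, y_pred), weights)
    return cm


def threshold_weighted_confusion_matrix(y_true, y_pred, weights, th=0.5):
    # One-hot labels -> class indices (assumes exactly one 1 per row)
    true_idx = y_true.nonzero()[1]
    # Signal-class probability above th -> predicted class 1, else 0
    pred_idx = (y_pred[:, 1] > th).astype(int)
    return weighted_confusion_matrix(true_idx, pred_idx, weights)


# Toy data: column 0 = background, column 1 = signal
y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
y_pred = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.4, 0.6]])
weights = np.array([1.0, 2.0, 0.5, 3.0])
print(threshold_weighted_confusion_matrix(y_true, y_pred, weights))
```

Each cell holds the summed weight of the samples that fall into it, so the matrix total equals the total weight rather than the sample count. For comparison, `sklearn.metrics.confusion_matrix` accepts a `sample_weight` argument and uses the same convention (row = true class, column = predicted class).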