metrics.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:qtim_ROP 作者: QTIM-Lab 项目源码 文件源码
def plot_PR_by_class(y_pred, y_true, classes, out_path):

    best_thresh = {}
    for class_name, c in classes.items():  # for each class

        # Compute ROC curve
        precision, recall, thresholds = precision_recall_curve(y_true[:, c], y_pred[:, c])
        pr_auc = auc(recall, precision)

        # Plot PR curve
        plt.plot(recall, precision, label='{}, AUC = {:.3f}'.format(class_name, pr_auc))

        # Calculate J statistic
        J = [j_statistic(y_true, y_pred, t) for t in thresholds]
        j_best = np.argmax(J)

        # Store best threshold for each class
        best_thresh[class_name] = J[j_best]

    return best_thresh
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号