pipeline.py source code

python

Project: syracuse_public    Author: dssg
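import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve


def plot_precision_recall_n(y_true, y_prob, model_name, pdf=None):
    """Plot precision and recall against the fraction of the population
    scored at or above each threshold.

    If `pdf` (a matplotlib PdfPages object) is given, the figure is saved
    to it; otherwise it is shown interactively.
    """
    # Convert to an array so the boolean comparison below works for lists too.
    y_score = np.asarray(y_prob)
    precision_curve, recall_curve, pr_thresholds = precision_recall_curve(
        y_true, y_score)
    # precision/recall have one more entry than thresholds (the final
    # precision=1, recall=0 point); drop it so the three arrays line up.
    precision_curve = precision_curve[:-1]
    recall_curve = recall_curve[:-1]
    # Fraction of all scored examples whose score is >= each threshold.
    number_scored = len(y_score)
    pct_above_per_thresh = np.array(
        [np.sum(y_score >= value) / float(number_scored)
         for value in pr_thresholds])

    fig, ax1 = plt.subplots()
    ax1.plot(pct_above_per_thresh, precision_curve, 'b')
    ax1.set_xlabel('percent of population')
    ax1.set_ylabel('precision', color='b')
    # Second y-axis sharing the same x-axis for recall.
    ax2 = ax1.twinx()
    ax2.plot(pct_above_per_thresh, recall_curve, 'r')
    ax2.set_ylabel('recall', color='r')
    ax1.set_title(model_name)

    if pdf is not None:
        pdf.savefig(fig)
        plt.close(fig)
    else:
        plt.show()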
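To make explicit what the two curves represent, here is a minimal, dependency-free sketch (standard library only) of the quantities the plot draws. For each candidate threshold it computes the fraction of the population scored at or above it, plus precision and recall among those flagged. The function name `precision_recall_by_population` is hypothetical and is not part of syracuse_public or scikit-learn. It treats every distinct score as a threshold, so its points are not guaranteed to match `precision_recall_curve` one for one.

```python
def precision_recall_by_population(y_true, y_score):
    """Hypothetical helper: return (threshold, pct_population, precision,
    recall) tuples, one per distinct score, in ascending threshold order.

    y_true: iterable of 0/1 labels; y_score: iterable of numeric scores.
    """
    pairs = list(zip(y_score, y_true))
    n = len(pairs)
    total_pos = sum(1 for _, label in pairs if label)
    if n == 0 or total_pos == 0:
        raise ValueError("need at least one example and one positive label")

    rows = []
    for thresh in sorted(set(y_score)):
        # Examples "flagged" at this threshold: score >= thresh.
        flagged = [label for score, label in pairs if score >= thresh]
        true_pos = sum(1 for label in flagged if label)
        pct_population = len(flagged) / n
        precision = true_pos / len(flagged)  # flagged is never empty here
        recall = true_pos / total_pos
        rows.append((thresh, pct_population, precision, recall))
    return rows


if __name__ == "__main__":
    labels = [0, 0, 1, 1]
    scores = [0.1, 0.4, 0.35, 0.8]
    for row in precision_recall_by_population(labels, scores):
        print(row)
```

As the threshold rises, the flagged share of the population shrinks and recall can only stay the same or fall, while precision may move in either direction. That trade-off is what the two y-axes in `plot_precision_recall_n` show side by side.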