evaluate.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:relation_classification 作者: hxy8149989 项目源码 文件源码
def save_prcurve(prob, answer, model_name, save_fn, use_neg=True):
    """Plot a precision-recall curve over all flattened scores and save it.

    Both ``prob`` and ``answer`` are flattened, so every (sample, class)
    entry is treated as one binary decision.

    Args:
        prob: Array-like of predicted scores. Indexed as ``[:, 1:]`` when
            ``use_neg`` is False, so it must be at least 2-D there
            (presumably one row per sample, one column per class — confirm).
        answer: Array-like of ground-truth labels with the same shape as
            ``prob`` (presumably one-hot / multi-hot — confirm).
        model_name: Legend label for the curve.
        save_fn: Destination passed to ``Figure.savefig``.
        use_neg: If False, column 0 of both ``prob`` and ``answer`` is
            discarded before scoring (presumably the negative / "no
            relation" class — confirm against the caller).

    Returns:
        None. The plot is written to ``save_fn``.
    """
    if use_neg:
        scores = np.ravel(prob)
        labels = np.ravel(answer)
    else:
        # Drop column 0 for every row in one slice. This is equivalent to
        # taking row[1:] for each row and flattening in row-major order.
        scores = np.asarray(prob)[:, 1:].ravel()
        labels = np.asarray(answer)[:, 1:].ravel()

    precision, recall, _ = precision_recall_curve(labels, scores)
    # The title labels this value "Area"; it is sklearn's average
    # precision score.
    average_precision = average_precision_score(labels, scores)

    # Draw on a dedicated figure rather than clearing pyplot's current one,
    # so a figure the caller is working on is not wiped. Always close it so
    # repeated calls do not accumulate open figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(recall, precision, lw=2, color='navy', label=model_name)
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall Area={0:0.2f}'.format(average_precision))
        ax.legend(loc="upper right")
        ax.grid(True)
        fig.savefig(save_fn)
    finally:
        plt.close(fig)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号