classify.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ISM2017 作者: ybayle 项目源码 文件源码
def plot_precision_recall(indir, gts_file, outdir):
    """Plot one precision-recall curve per prediction CSV found in ``indir``.

    Ground truths are loaded from ``gts_file`` with ``read_item_tag``. Each
    file in ``indir`` whose name ends with ``.csv`` is loaded with
    ``read_preds``; only the identifiers that also appear in the ground
    truths are kept. A ground-truth tag equal to ``"s"`` is treated as the
    positive class. Every curve is labelled with the file name (extension
    removed) followed by its average precision rounded to 3 decimals.

    The figure is saved as ``precision_recall.png`` inside ``outdir``.

    Args:
        indir: Directory containing the prediction ``.csv`` files.
        gts_file: Path of the ground-truth file.
        outdir: Directory in which the PNG is written.
    """
    groundtruths = read_item_tag(gts_file)
    plt.figure(1)

    indir = utils.abs_path_dir(indir)
    # os.listdir() order is arbitrary; sorting keeps the legend order and
    # the colour assigned to each curve stable between runs.
    for item in sorted(os.listdir(indir)):
        # Match on the real extension: a substring test would also accept
        # names such as "x.csv.bak" and then produce a mangled label.
        if not item.endswith(".csv"):
            continue
        preds = read_preds(os.path.join(indir, item))
        test_groundtruths = []
        predictions = []
        for isrc in preds:
            if isrc in groundtruths:
                # Binarise on the fly: tag "s" is the positive class.
                test_groundtruths.append(groundtruths[isrc] == "s")
                predictions.append(preds[isrc])
        precision, recall, _ = precision_recall_curve(
            test_groundtruths, predictions)
        avg_precision = average_precision_score(test_groundtruths, predictions)
        curve_name = os.path.splitext(item)[0]
        plt.plot(recall, precision,
                 label=curve_name + " (" + str(round(avg_precision, 3)) + ")")

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([-0.05, 1.05])
    plt.title('Precision-Recall curve for Algo (AUC)')
    plt.legend(loc='best')
    # os.path.join inserts the separator when outdir lacks a trailing one
    # and leaves the path unchanged when it is already there.
    plt.savefig(os.path.join(outdir, "precision_recall.png"), dpi=200,
                bbox_inches="tight")
    plt.close()
    utils.print_success("Precision-Recall curve created in " + outdir)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号