def plot_precision_recall(indir, gts_file, outdir):
    """Plot one precision-recall curve per prediction CSV file found in *indir*.

    Each ``*.csv`` file in *indir* is loaded with ``read_preds`` and compared
    against the groundtruth loaded from *gts_file* with ``read_item_tag``.
    Only items present in both are used; an item is treated as positive when
    its groundtruth tag equals ``"s"``. Each curve is labelled with the file
    name (without its ``.csv`` extension) followed by its average precision
    score rounded to three decimals. All curves are drawn on one figure, which
    is saved as ``precision_recall.png`` inside *outdir*.

    Args:
        indir: Directory containing the prediction CSV files; normalised with
            ``utils.abs_path_dir``.
        gts_file: Groundtruth file passed to ``read_item_tag``.
        outdir: Directory in which ``precision_recall.png`` is written.
    """
    groundtruths = read_item_tag(gts_file)
    # A fresh figure, rather than plt.figure(1), so curves are never drawn on
    # top of a figure 1 that was left open by other code.
    fig = plt.figure()
    indir = utils.abs_path_dir(indir)
    # Sorted so the curve/legend order does not depend on the file system.
    for item in sorted(os.listdir(indir)):
        # Match the extension only: the label below strips exactly 4 chars.
        if not item.endswith(".csv"):
            continue
        isrcs = read_preds(os.path.join(indir, item))
        # Items that have a groundtruth tag, in the iteration order of isrcs.
        common = [isrc for isrc in isrcs if isrc in groundtruths]
        if not common:
            # scikit-learn's curve/score functions reject empty input; skip
            # this file instead of aborting the whole plot.
            print("No groundtruth available for the items in " + item + ", skipped")
            continue
        test_groundtruths = [groundtruths[isrc] == "s" for isrc in common]
        predictions = [isrcs[isrc] for isrc in common]
        precision, recall, _ = precision_recall_curve(test_groundtruths, predictions)
        avg_precision = average_precision_score(test_groundtruths, predictions)
        plt.plot(
            recall,
            precision,
            label=item[:-4] + " (" + str(round(avg_precision, 3)) + ")",
        )
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([-0.05, 1.05])
    # NOTE(review): the legend values are average precision scores, which the
    # title refers to as AUC.
    plt.title('Precision-Recall curve for Algo (AUC)')
    plt.legend(loc='best')
    # os.path.join works whether or not outdir ends with a path separator.
    plt.savefig(os.path.join(outdir, "precision_recall.png"), dpi=200, bbox_inches="tight")
    plt.close(fig)
    utils.print_success("Precision-Recall curve created in " + outdir)
评论列表
文章目录