def save_prcurve(prob, answer, model_name, save_fn, use_neg=True):
    """Plot a precision-recall curve and save it to ``save_fn``.

    ``prob`` and ``answer`` are flattened before scoring. Every cell is
    therefore treated as one independent binary decision, which amounts to
    micro-averaging over all columns.

    Args:
        prob: Predicted scores, array-like. Assumed to be
            (n_samples, n_classes) — TODO confirm with callers. It must be at
            least 2-D when ``use_neg`` is False.
        answer: Ground-truth labels with the same shape as ``prob``.
            Presumably 0/1 indicators — verify against callers.
        model_name: Legend label for the curve.
        save_fn: Destination passed to ``Figure.savefig`` (path or file-like).
        use_neg: If False, drop column 0 from every row of ``prob`` and
            ``answer`` before scoring. That column is presumably the
            "negative"/no-label class — confirm.

    Returns:
        None. The plot is written to ``save_fn`` and the figure is closed.
    """
    prob = np.asarray(prob)
    answer = np.asarray(answer)
    if not use_neg:
        # Same as taking p[1:] for each row p: drop the first column.
        prob = prob[:, 1:]
        answer = answer[:, 1:]
    prob = prob.reshape(-1)
    ans = answer.reshape(-1)

    # The thresholds returned by precision_recall_curve are not needed here.
    precision, recall, _ = precision_recall_curve(ans, prob)
    average_precision = average_precision_score(ans, prob)

    # Draw on a dedicated figure so the caller's current figure is not
    # cleared. Close it afterwards so repeated calls do not accumulate
    # open figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(recall, precision, lw=2, color='navy', label=model_name)
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall Area={0:0.2f}'.format(average_precision))
        ax.legend(loc="upper right")
        ax.grid(True)
        fig.savefig(save_fn)
    finally:
        plt.close(fig)
# (Web-page residue, not code: "评论列表" = "comment list",
#  "文章目录" = "table of contents".)