def plot_precision_recall_n(y_true, y_prob, model_name, pdf=None):
    """Plot precision and recall against the fraction of the population selected.

    For every threshold returned by ``precision_recall_curve``, the
    x-coordinate is the fraction of scores that are >= that threshold.
    Precision is drawn in blue on the left y-axis and recall in red on a
    twinned right y-axis.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels, passed unchanged to ``precision_recall_curve``.
    y_prob : array-like
        Predicted score per sample. Lists, NumPy arrays and other array-likes
        are accepted; the value is converted with ``np.asarray``.
    model_name : str
        Used as the plot title.
    pdf : optional
        If truthy, an object with a ``savefig(figure)`` method (presumably a
        matplotlib ``PdfPages`` -- confirm with callers). The figure is
        written to it and then closed. Otherwise ``plt.show()`` is called.
    """
    y_score = np.asarray(y_prob)
    precision_curve, recall_curve, pr_thresholds = precision_recall_curve(
        y_true, y_score)
    # precision/recall carry one more point than thresholds (the final
    # point has no threshold); drop it so all three arrays line up.
    precision_curve = precision_curve[:-1]
    recall_curve = recall_curve[:-1]
    pct_above_per_thresh = _fraction_at_or_above(y_score, pr_thresholds)

    # Build a fresh figure explicitly. The former plt.clf() cleared (or
    # created) a separate, unrelated figure that was never closed.
    fig, ax1 = plt.subplots()
    ax1.plot(pct_above_per_thresh, precision_curve, 'b')
    ax1.set_xlabel('percent of population')
    ax1.set_ylabel('precision', color='b')
    ax2 = ax1.twinx()
    ax2.plot(pct_above_per_thresh, recall_curve, 'r')
    ax2.set_ylabel('recall', color='r')
    ax1.set_title(model_name)
    if pdf:
        pdf.savefig(fig)
        # Close exactly the figure we created so repeated calls don't leak.
        plt.close(fig)
    else:
        plt.show()


def _fraction_at_or_above(scores, thresholds):
    """Return, for each threshold, the fraction of ``scores`` that are >= it."""
    scores = np.ravel(np.asarray(scores, dtype=float))
    thresholds = np.asarray(thresholds, dtype=float)
    if scores.size == 0:
        return np.zeros(thresholds.shape, dtype=float)
    sorted_scores = np.sort(scores)
    # side="left" yields the number of scores strictly below each threshold,
    # so the remainder is the count of scores >= threshold.
    num_below = np.searchsorted(sorted_scores, thresholds, side="left")
    return (scores.size - num_below) / float(scores.size)
# Web-page residue from the original source (translated):
# "Comment list" / "Table of contents"