def plot_roc_curve(true_y, prob_y, out_file=None):
from sklearn.metrics import roc_curve
fpr, tpr, _ = roc_curve(true_y, prob_y)
fig = plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve')
plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')
plt.xlim([-0.025, 1.025])
plt.ylim([-0.025, 1.025])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('RoC Curve')
if out_file is not None:
fig.savefig(out_file)
return fig
评论列表
文章目录