def roc_plot_from_thresholds(roc_thresholds_by_model, save=False, debug=False):
    """
    From a given dictionary of thresholds by model, create a ROC curve for each model.

    A reference diagonal from (0, 0) to (1, 1) is drawn first. Then each model's
    curve is drawn in its own color, together with a star marker at that model's
    best operating point.

    Args:
        roc_thresholds_by_model (dict): A dictionary of ROC metrics keyed by model name.
            Each value must be a dict containing the keys 'roc_auc',
            'true_positive_rates', 'false_positive_rates',
            'best_true_positive_rate' and 'best_false_positive_rate'.
        save (bool): False to display the image (default), or True to save it
            as 'ROC.png' in the current working directory without displaying it.
        debug (bool): verbose output; prints each model's FPR/TPR table.
    """
    # TODO consolidate this and PR plotter into 1 function
    # TODO make the colors randomly generated from rgb values
    # Cycle through a fixed list of matplotlib color codes; colors repeat
    # once there are more than seven models.
    color_iterator = itertools.cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])

    # Initialize plot
    plt.figure()
    plt.xlabel('False Positive Rate (FPR)')
    plt.ylabel('True Positive Rate (TPR)')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.xlim([0.0, 1.0])
    # Slight headroom above 1.0, presumably so curves reaching TPR=1 aren't
    # clipped against the top of the frame.
    plt.ylim([0.0, 1.05])
    # Reference diagonal from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], linestyle=DIAGONAL_LINE_STYLE, color=DIAGONAL_LINE_COLOR)

    # Plot the curve and best operating point for each model.
    # zip() stops when the models run out, so the infinite color cycle is safe.
    for color, (model_name, metrics) in zip(color_iterator, roc_thresholds_by_model.items()):
        roc_auc = metrics['roc_auc']
        tpr = metrics['true_positive_rates']
        fpr = metrics['false_positive_rates']
        best_true_positive_rate = metrics['best_true_positive_rate']
        best_false_positive_rate = metrics['best_false_positive_rate']

        if debug:
            print('{} model:'.format(model_name))
            print(pd.DataFrame({'FPR': fpr, 'TPR': tpr}))

        # Plot the curve, labelled with the AUC rounded to two decimals.
        label = '{} (ROC AUC = {})'.format(model_name, round(roc_auc, 2))
        plt.plot(fpr, tpr, color=color, label=label)
        # Star marker at the model's best (FPR, TPR) point, in the curve's color.
        plt.plot([best_false_positive_rate], [best_true_positive_rate],
                 marker='*', markersize=10, color=color)

    plt.legend(loc="lower right")

    if save:
        # savefig() with a relative file name writes into the current working
        # directory, so report that directory rather than this module's location.
        save_path = 'ROC.png'
        plt.savefig(save_path)
        # Close the figure so repeated calls don't accumulate open figures.
        plt.close()
        print('\nROC plot saved in: {}'.format(os.path.dirname(os.path.abspath(save_path))))
    else:
        # Only display when not saving, as the docstring promises.
        plt.show()
评论列表
文章目录