def pr_plot_from_thresholds(pr_thresholds_by_model, save=False, debug=False):
    """
    Plot one precision-recall (PR) curve per model on a single figure.

    Each model's curve is drawn in its own color, and its best
    (recall, precision) point is marked with a star in the same color.

    Args:
        pr_thresholds_by_model (dict): A dictionary of PR thresholds by model name. Each value is a
            metrics dictionary; the keys read here are 'pr_auc', 'precisions', 'recalls',
            'best_recall' and 'best_precision'.
        save (bool): False to display the image (default) or True to save it as 'PR.png' in the
            current working directory (but not display it).
        debug (bool): Verbose output: print each model's recall/precision table.
    """
    # TODO consolidate this and PR plotter into 1 function
    # TODO make the colors randomly generated from rgb values
    # Cycle through the colors list so colors are reused when there are more models than colors
    color_iterator = itertools.cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])

    # Initialize plot (keep a handle so the figure can be closed after saving)
    figure = plt.figure()
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision Recall (PR)')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    # Reference diagonal from (0, 1) to (1, 0)
    plt.plot([0, 1], [1, 0], linestyle=DIAGONAL_LINE_STYLE, color=DIAGONAL_LINE_COLOR)

    # Calculate and plot for each model
    for color, (model_name, metrics) in zip(color_iterator, pr_thresholds_by_model.items()):
        # Extract metrics from dictionary
        pr_auc = metrics['pr_auc']
        precision = metrics['precisions']
        recall = metrics['recalls']
        best_recall = metrics['best_recall']
        best_precision = metrics['best_precision']

        if debug:
            print('{} model:'.format(model_name))
            print(pd.DataFrame({'Recall': recall, 'Precision': precision}))

        # plot the line
        label = '{} (PR AUC = {})'.format(model_name, round(pr_auc, 2))
        plt.plot(recall, precision, color=color, label=label)
        # Mark the best point on this model's curve
        plt.plot([best_recall], [best_precision], marker='*', markersize=10, color=color)

    plt.legend(loc="lower left")

    if save:
        file_name = 'PR.png'
        plt.savefig(file_name)
        # A relative file name is written to the current working directory, so report
        # that location rather than the directory of this source file.
        save_directory = os.path.dirname(os.path.abspath(file_name))
        print('\nPR plot saved in: {}'.format(save_directory))
        # The figure is not shown in save mode; close it so repeated calls do not pile up figures.
        plt.close(figure)
    else:
        plt.show()
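# Web page residue (not part of the module): "Comment list"
# Web page residue (not part of the module): "Table of contents"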