import numpy as np
from sklearn import metrics as skmetrics


def compute_pr(y_test, probability_predictions):
    """
    Compute the precision-recall curve, its thresholds, and the PR AUC.

    Args:
        y_test (list): true binary labels corresponding to the predictions, of length n.
        probability_predictions (list): predicted positive-class probabilities
            from an ML algorithm, also of length n.

    Returns:
        dict: the PR AUC, the cutoff closest to perfect classification along with
            its precision and recall, and the full arrays of precisions, recalls,
            and thresholds.
    """
    # Helper assumed to check that both inputs have equal length; a sketch is provided below.
    _validate_predictions_and_labels_are_equal_length(probability_predictions, y_test)
    # Calculate the PR curve and average precision (PR AUC)
    precisions, recalls, pr_thresholds = skmetrics.precision_recall_curve(y_test, probability_predictions)
    pr_auc = skmetrics.average_precision_score(y_test, probability_predictions)
    # Get the ideal cutoff for suggestions: the point closest to the upper-right
    # corner (1, 1). Note that precision_recall_curve returns one more
    # (precision, recall) pair than it does thresholds -- the final (1, 0) point
    # has no threshold -- so exclude it from the search to keep indices aligned.
    pr_distances = (precisions[:-1] - 1) ** 2 + (recalls[:-1] - 1) ** 2
    # np.argmin returns the index of the first minimum, which breaks ties when
    # several points share the same minimum distance.
    pr_index = np.argmin(pr_distances)
    best_precision = precisions[pr_index]
    best_recall = recalls[pr_index]
    ideal_pr_cutoff = pr_thresholds[pr_index]
    return {'pr_auc': pr_auc,
            'best_pr_cutoff': ideal_pr_cutoff,
            'best_precision': best_precision,
            'best_recall': best_recall,
            'precisions': precisions,
            'recalls': recalls,
            'pr_thresholds': pr_thresholds}
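
# _validate_predictions_and_labels_are_equal_length is referenced above but not
# defined in this snippet; this is a minimal sketch of what it presumably does,
# inferred only from its name, not from the original source.
def _validate_predictions_and_labels_are_equal_length(probability_predictions, y_test):
    if len(probability_predictions) != len(y_test):
        raise ValueError('probability_predictions and y_test must be the same length.')


# Illustrative usage with synthetic data (not from the original source). In
# practice probability_predictions would come from something like
# clf.predict_proba(X_test)[:, 1].
if __name__ == '__main__':
    y_test = [0, 0, 1, 1, 0, 1, 1, 0]
    probability_predictions = [0.1, 0.4, 0.35, 0.8, 0.2, 0.7, 0.65, 0.3]
    pr_results = compute_pr(y_test, probability_predictions)
    print(f"PR AUC: {pr_results['pr_auc']:.3f}")
    print(f"Best cutoff: {pr_results['best_pr_cutoff']:.3f} "
          f"(precision={pr_results['best_precision']:.3f}, "
          f"recall={pr_results['best_recall']:.3f})")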