def plot_PR_by_class(y_pred, y_true, classes, out_path):
    """Plot one precision-recall curve per class and pick a threshold per class.

    For every class, the precision-recall curve of that class's score column is
    drawn on a single shared figure, labelled with its PR-AUC. Among the
    thresholds returned by ``precision_recall_curve``, the one with the highest
    ``j_statistic`` is selected. The figure is saved to ``out_path``.

    Args:
        y_pred: Predicted scores, indexed per class as ``y_pred[:, c]``
            (presumably shape (n_samples, n_classes) -- confirm with callers).
        y_true: Ground-truth labels with the same column layout as ``y_pred``.
        classes: Mapping of class name -> column index ``c``.
        out_path: Destination path for the saved PR figure.

    Returns:
        dict mapping each class name to the threshold (a value from that
        class's ``precision_recall_curve`` thresholds) that maximizes the
        J statistic.
    """
    # Use a dedicated figure so repeated calls don't draw onto each other's axes.
    fig = plt.figure()
    best_thresh = {}
    for class_name, c in classes.items():
        # Work on this class's columns only.
        y_true_c = y_true[:, c]
        y_score_c = y_pred[:, c]

        # Compute the precision-recall curve and its area for this class.
        precision, recall, thresholds = precision_recall_curve(y_true_c, y_score_c)
        pr_auc = auc(recall, precision)
        plt.plot(recall, precision, label='{}, AUC = {:.3f}'.format(class_name, pr_auc))

        # Evaluate the J statistic at every candidate threshold for this class.
        # NOTE(review): assumes j_statistic(y_true, y_score, threshold) accepts a
        # single 1-D label/score column -- confirm against its definition.
        J = [j_statistic(y_true_c, y_score_c, t) for t in thresholds]
        j_best = np.argmax(J)

        # Store the threshold itself, not its J score.
        best_thresh[class_name] = thresholds[j_best]

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend(loc='lower left')
    fig.savefig(out_path)
    plt.close(fig)
    return best_thresh
评论列表
文章目录