def plot_ROC_by_class(y_true, y_pred, classes, ls='-'):
    """Plot a one-vs-rest ROC curve per class and pick a threshold for each.

    For every class, column ``c`` of ``y_pred`` is scored against column
    ``c`` of ``y_true`` with ``roc_curve``. The curve is drawn on the
    current matplotlib axes, labelled with the class name and its AUC. The
    ROC threshold that maximises ``j_statistic`` is recorded.

    Args:
        y_true: 2-D array indexed as ``y_true[:, c]``; presumably one binary
            indicator column per class -- TODO confirm against callers.
        y_pred: 2-D array of scores with the same column layout as
            ``y_true``.
        classes: mapping of class name -> column index into ``y_true`` and
            ``y_pred``.
        ls: matplotlib linestyle applied to every plotted curve.

    Returns:
        dict mapping each class name to the threshold (one of the values
        returned by ``roc_curve``) at which ``j_statistic`` is largest.
        NOTE(review): ``roc_curve``'s first threshold is a sentinel above
        every score (``np.inf`` in recent scikit-learn). It is selected only
        if no other threshold yields a larger J.
    """
    # Shape diagnostics. The single-argument print() form prints the same
    # thing on Python 2 and Python 3.
    print(y_true.shape)
    print(y_pred.shape)
    best_thresh = {}
    for class_name, c in classes.items():  # for each class
        labels = y_true[:, c]
        scores = y_pred[:, c]
        # Compute and plot the ROC curve for this class.
        fpr, tpr, thresholds = roc_curve(labels, scores)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr,
                 label='{}, AUC = {:.3f}'.format(class_name, roc_auc),
                 linestyle=ls)
        # Evaluate the J statistic at every candidate threshold.
        # NOTE(review): if j_statistic is Youden's J (TPR - FPR), this equals
        # tpr - fpr and the per-threshold recomputation could be dropped.
        # Confirm j_statistic's threshold comparison first.
        J = [j_statistic(labels, scores, t) for t in thresholds]
        j_best = np.argmax(J)
        # Keep the threshold itself, not its J score.
        best_thresh[class_name] = thresholds[j_best]
    return best_thresh
评论列表
文章目录