def plot_roc_per_class(model, test_data, test_truth, labels, title):
    """Plot one ROC curve per class plus the macro-average curve, and save as JPG.

    Args:
        model: trained classifier passed through to get_fpr_tpr_roc.
        test_data: evaluation inputs (format expected by get_fpr_tpr_roc).
        test_truth: ground-truth labels for test_data.
        labels: dict mapping class name -> class index; iteration order is
            assumed to match the per-class curve indices 0..n_classes-1
            (NOTE(review): confirm against get_fpr_tpr_roc's ordering).
        title: plot title suffix; also embedded in the output filename.

    Side effects:
        Writes "./per_class_roc_<title>.jpg" and closes the figure.
    """
    fpr, tpr, roc_auc = get_fpr_tpr_roc(model, test_data, test_truth, labels)
    n_classes = len(labels)

    # Macro-average: aggregate all per-class FPR grid points, then average
    # the interpolated TPRs over classes.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        # np.interp is the supported replacement for the removed scipy.interp
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Plot all ROC curves
    lw = 2
    plt.figure(figsize=(20, 16))
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green', 'pink', 'magenta', 'grey', 'purple'])
    for idx, (key, color) in enumerate(zip(labels.keys(), colors)):
        plt.plot(fpr[idx], tpr[idx], color=color, lw=lw, label='ROC curve of class '+str(key))
    # Previously computed but never drawn: the macro-average curve.
    plt.plot(fpr["macro"], tpr["macro"], color='navy', linestyle=':', lw=lw,
             label='macro-average ROC (area = %0.2f)' % roc_auc["macro"])
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC:'+ str(labels) + '\n' + title)
    plt.legend(loc="lower right")
    plt.savefig("./per_class_roc_"+title+".jpg")
    # Close so repeated calls don't accumulate open figures (memory leak).
    plt.close()
# --- Non-code page metadata from the original source scrape (commented out
# --- so this file parses as Python). Original text (translated):
# chrom_hmm_cnn.py file source
# python
# Reads: 41
# Bookmarks: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents