recognition_utils.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pybot 作者: spillai 项目源码 文件源码
def plot_roc(y_score, y_test, target_map, title='ROC curve'):
    """Plot ROC curves (micro-average and per class) for a classifier.

    Two figures are drawn:
      1. The micro-averaged ROC curve over all classes, plus the chance
         diagonal, shown with a blocking ``plt.show()``.
      2. One ROC curve per class, titled ``title``, shown with
         ``plt.show(block=False)``.

    Args:
        y_score: Array-like of classifier scores. Column ``i`` must hold the
            scores for the ``i``-th entry of ``target_map`` (dict iteration
            order). Assumed shape (n_samples, n_classes) -- TODO confirm
            against callers.
        y_test: Array-like of ground-truth labels whose values are keys of
            ``target_map``.
        target_map: Mapping from class label (as it appears in ``y_test``)
            to a class-name string; names are used as legend entries.
        title: Title applied to the per-class figure.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, auc
    from sklearn.preprocessing import label_binarize

    # Per-curve results, keyed by class name plus the special key "micro".
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    # Materialize keys/values once: dict views are not plain array-likes on
    # Python 3, and taking both from the same dict keeps their order aligned.
    target_ids = list(target_map.keys())
    target_names = list(target_map.values())
    print(target_names)

    y_score = np.asarray(y_score)
    y_test_multi = label_binarize(y_test, classes=target_ids)
    # For exactly two classes label_binarize returns ONE column (the
    # indicator of target_ids[1]). Expand it to one column per class so that
    # column i lines up with y_score[:, i], as in the multi-class case.
    if len(target_ids) == 2 and y_test_multi.shape[1] == 1:
        y_test_multi = np.hstack([1 - y_test_multi, y_test_multi])

    # Per-class ROC curve and area under it.
    for i, name in enumerate(target_names):
        fpr[name], tpr[name], _ = roc_curve(y_test_multi[:, i], y_score[:, i])
        roc_auc[name] = auc(fpr[name], tpr[name])

    # Micro-average: pool every (sample, class) decision into one ROC curve.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test_multi.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Figure 1: micro-averaged ROC curve against the chance diagonal.
    plt.clf()
    plt.plot([0, 1], [0, 1], 'k--')
    plt.plot(fpr["micro"], tpr["micro"],
             label='ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["micro"]), linewidth=3)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.ylim([0.0, 1.0])
    plt.xlim([0.0, 1.0])
    plt.legend(loc="lower right")
    plt.show()

    # Figure 2: one ROC curve per class, legend shows a prettified name.
    for name in target_names:
        plt.plot(fpr[name], tpr[name],
                 label='{0}'.format(name.title().replace('_', ' ')))

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show(block=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号