recognition_utils.py source code

python
Project: pybot  Author: spillai
def multilabel_precision_recall(y_score, y_test, clf_target_ids, clf_target_names):
    import numpy as np
    from sklearn.metrics import precision_recall_curve
    from sklearn.metrics import average_precision_score
    from sklearn.preprocessing import label_binarize

    # Per-class precision-recall curves and average precision, keyed by target name
    precision = dict()
    recall = dict()
    average_precision = dict()

    # Map each target id to its name and to its column index in y_score
    clf_target_map = { k: v for k,v in zip(clf_target_ids, clf_target_names)}
    id2ind = {tid: idx for (idx,tid) in enumerate(clf_target_ids)}

    # Only handle the targets encountered
    unique = np.unique(y_test)
    nzinds = np.int64([id2ind[target] for target in unique])

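    # Binarize (one column per target in `unique`) and create precision-recall curves.
    # Note: with exactly two distinct targets label_binarize returns a single
    # column, so the per-column indexing below would fail in that case.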
    y_test_multi = label_binarize(y_test, classes=unique)
    for i,target in enumerate(unique):
        index = id2ind[target]
        name = clf_target_map[target]
        precision[name], recall[name], _ = precision_recall_curve(y_test_multi[:, i],
                                                                  y_score[:, index])
        average_precision[name] = average_precision_score(y_test_multi[:, i], y_score[:, index])

    # Compute the micro-averaged precision-recall curve (stored under "average")
    # and the micro- and macro-averaged average-precision scores
    precision["average"], recall["average"], _ = precision_recall_curve(y_test_multi.ravel(),
                                                                        y_score[:,nzinds].ravel())
    average_precision["micro"] = average_precision_score(y_test_multi, y_score[:,nzinds],
                                                         average="micro")
    average_precision["macro"] = average_precision_score(y_test_multi, y_score[:,nzinds],
                                                         average="macro")
    return precision, recall, average_precision
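
To make the per-class number stored in average_precision[name] more concrete, here is a minimal pure-Python sketch of step-wise average precision for a single binarized class. The name average_precision_sketch is hypothetical and not part of pybot or scikit-learn. It assumes 0/1 labels with at least one positive and distinct scores; the function above calls scikit-learn's average_precision_score, which also handles tied scores.

```python
def average_precision_sketch(y_true, y_score):
    """Step-wise average precision for one binary class (illustrative only).

    Assumes y_true holds 0/1 labels with at least one positive and that
    all scores are distinct (no tie handling, unlike scikit-learn).
    """
    # Rank samples from highest to lowest score
    order = sorted(range(len(y_score)), key=lambda i: y_score[i], reverse=True)
    n_pos = sum(y_true)
    hits = 0
    ap = 0.0
    for rank, i in enumerate(order, start=1):
        if y_true[i]:
            hits += 1
            # Each positive raises recall by 1/n_pos; weight that step
            # by the precision reached at this rank
            ap += (hits / rank) / n_pos
    return ap


# One column of a binarized label matrix and the matching score column
y_true = [1, 0, 1, 0]
y_score = [0.9, 0.8, 0.7, 0.1]
ap = average_precision_sketch(y_true, y_score)
print(ap)
```

Here the two positives are ranked first and third, so the result is (1/1 + 2/3) / 2. The macro average in the function above is the unweighted mean of such per-class values, while the micro average pools all (sample, class) pairs before scoring.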