model_eval.py source code

python

Project: healthcareai-py    Author: HealthCatalyst
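import itertools
import os

import matplotlib.pyplot as plt
import pandas as pd

# DIAGONAL_LINE_STYLE and DIAGONAL_LINE_COLOR are module-level constants
# defined elsewhere in model_eval.py (not shown in this snippet).


def pr_plot_from_thresholds(pr_thresholds_by_model, save=False, debug=False):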
    """
    Plot a precision-recall (PR) curve for each model in a dictionary of PR metrics, all on one figure.

    Args:
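        pr_thresholds_by_model (dict): Maps each model name to a dict of PR metrics with the keys
            'pr_auc', 'precisions', 'recalls', 'best_recall' and 'best_precision'.
        save (bool): False (default) to display the plot; True to save it as PR.png in the
            current working directory without displaying it.
        debug (bool): If True, print each model's recall/precision table (verbose output).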
    """
    # TODO consolidate this and PR plotter into 1 function
    # TODO make the colors randomly generated from rgb values
    # Cycle through the colors list
    color_iterator = itertools.cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])
    # Initialize plot
    plt.figure()
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision Recall (PR)')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.plot([0, 1], [1, 0], linestyle=DIAGONAL_LINE_STYLE, color=DIAGONAL_LINE_COLOR)

    # Calculate and plot for each model
    for color, (model_name, metrics) in zip(color_iterator, pr_thresholds_by_model.items()):
        # Extract this model's PR metrics
        pr_auc = metrics['pr_auc']
        precision = metrics['precisions']
        recall = metrics['recalls']
        best_recall = metrics['best_recall']
        best_precision = metrics['best_precision']

        if debug:
            print('{} model:'.format(model_name))
            print(pd.DataFrame({'Recall': recall, 'Precision': precision}))

        # Plot the PR curve and mark the best point with a star
        label = '{} (PR AUC = {})'.format(model_name, round(pr_auc, 2))
        plt.plot(recall, precision, color=color, label=label)
        plt.plot([best_recall], [best_precision], marker='*', markersize=10, color=color)

    plt.legend(loc="lower left")

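    if save:
        # savefig() with a relative path writes to the current working directory,
        # not to the directory containing this source file
        plt.savefig('PR.png')
        print('\nPR plot saved in: {}'.format(os.getcwd()))
        plt.close()  # the docstring promises the saved plot is not displayed
    else:
        plt.show()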
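The function only plots precomputed metrics; the code that builds `pr_thresholds_by_model` is not part of this snippet. Below is a minimal, dependency-free sketch of one way to build a single entry, assuming binary labels (1 = positive) and that a higher score means "more likely positive". The helper name `pr_metrics_from_scores`, the extra `best_threshold` key, picking the maximum-F1 point as "best", and the trapezoidal PR AUC are all assumptions for illustration, not healthcareai-py's own implementation.

```python
from typing import Dict, List, Sequence, Union


def pr_metrics_from_scores(
    y_true: Sequence[int], y_score: Sequence[float]
) -> Dict[str, Union[float, List[float]]]:
    """Hypothetical helper: build one value of ``pr_thresholds_by_model``.

    Each distinct score is used as a threshold (predict positive when
    score >= threshold). Labels are assumed binary, with 1 meaning positive.
    """
    if len(y_true) != len(y_score):
        raise ValueError('y_true and y_score must have the same length')
    positives = sum(1 for y in y_true if y == 1)
    if positives == 0:
        raise ValueError('y_true must contain at least one positive label')

    thresholds = sorted(set(y_score), reverse=True)  # strictest threshold first
    precisions: List[float] = []
    recalls: List[float] = []
    for t in thresholds:
        tp = sum(1 for y, s in zip(y_true, y_score) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(y_true, y_score) if s >= t and y != 1)
        # tp + fp >= 1 because the sample scoring exactly t is predicted positive
        precisions.append(tp / (tp + fp))
        recalls.append(tp / positives)

    # Trapezoidal area under the PR curve, starting from (recall=0, precision=1).
    # Recall never decreases as the threshold is lowered, so widths are >= 0.
    pr_auc = 0.0
    prev_r, prev_p = 0.0, 1.0
    for r, p in zip(recalls, precisions):
        pr_auc += (r - prev_r) * (p + prev_p) / 2
        prev_r, prev_p = r, p

    # Assumed definition of "best": the threshold with the highest F1 score.
    def f1(i: int) -> float:
        p, r = precisions[i], recalls[i]
        return 0.0 if p + r == 0 else 2 * p * r / (p + r)

    best = max(range(len(thresholds)), key=f1)
    return {
        'pr_auc': pr_auc,
        'precisions': precisions,
        'recalls': recalls,
        'best_recall': recalls[best],
        'best_precision': precisions[best],
        'best_threshold': thresholds[best],  # extra key; the plotter ignores it
    }


metrics = pr_metrics_from_scores([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(round(metrics['pr_auc'], 4))  # 0.7917
print(metrics['best_threshold'])    # 0.35
```

Each returned dict can be stored under a model name, e.g. `{'Logistic regression': pr_metrics_from_scores(y_true, y_score)}`, and passed as `pr_thresholds_by_model`. The threshold sweep is quadratic in the number of samples, which is fine for illustration; for real data, a single pass over sorted scores with cumulative counts (or scikit-learn's `precision_recall_curve`) is the usual choice. Note that trapezoidal interpolation tends to overestimate PR AUC compared with the step-wise average precision.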