model_eval.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:healthcareai-py 作者: HealthCatalyst 项目源码 文件源码
def roc_plot_from_thresholds(roc_thresholds_by_model, save=False, debug=False):
    """
    From a given dictionary of thresholds by model, create a ROC curve for each model.

    Each model's curve is drawn in its own color (colors repeat when there are more
    models than colors), and that model's best operating point is marked with a star.
    A diagonal reference line from (0, 0) to (1, 1) is drawn underneath.

    Args:
        roc_thresholds_by_model (dict): A dictionary of ROC thresholds by model name.
            Each value must be a mapping containing the keys 'roc_auc',
            'true_positive_rates', 'false_positive_rates', 'best_true_positive_rate'
            and 'best_false_positive_rate'.
        save (bool): False to display the image (default) or True to save it as
            'ROC.png' in the current working directory (but not display it).
        debug (bool): verbose output; prints each model's FPR/TPR table.
    """
    # TODO consolidate this and PR plotter into 1 function
    # TODO make the colors randomly generated from rgb values
    # Cycle through the colors list so any number of models gets a color.
    color_iterator = itertools.cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])

    # Initialize plot
    plt.figure()
    plt.xlabel('False Positive Rate (FPR)')
    plt.ylabel('True Positive Rate (TPR)')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.xlim([0.0, 1.0])
    # Small headroom above 1.0 so a curve that reaches TPR = 1 is not clipped by the frame.
    plt.ylim([0.0, 1.05])
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], linestyle=DIAGONAL_LINE_STYLE, color=DIAGONAL_LINE_COLOR)

    # Plot each model's curve and best operating point in a shared color.
    for color, (model_name, metrics) in zip(color_iterator, roc_thresholds_by_model.items()):
        roc_auc = metrics['roc_auc']
        tpr = metrics['true_positive_rates']
        fpr = metrics['false_positive_rates']
        best_true_positive_rate = metrics['best_true_positive_rate']
        best_false_positive_rate = metrics['best_false_positive_rate']

        if debug:
            print('{} model:'.format(model_name))
            print(pd.DataFrame({'FPR': fpr, 'TPR': tpr}))

        # plot the line
        label = '{} (ROC AUC = {})'.format(model_name, round(roc_auc, 2))
        plt.plot(fpr, tpr, color=color, label=label)
        # Star marker at the model's best (FPR, TPR) point.
        plt.plot([best_false_positive_rate], [best_true_positive_rate], marker='*', markersize=10, color=color)

    # Only build a legend when at least one labelled curve exists; an empty
    # legend only produces a matplotlib warning.
    if roc_thresholds_by_model:
        plt.legend(loc="lower right")

    if save:
        file_name = 'ROC.png'
        plt.savefig(file_name)
        # savefig resolves a relative path against the current working directory,
        # so report that directory rather than this module's location.
        saved_dir = os.path.dirname(os.path.abspath(file_name))
        print('\nROC plot saved in: {}'.format(saved_dir))
        # The figure is not displayed in save mode, so release it.
        plt.close()
    else:
        plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号