plots.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:guacml 作者: guacml 项目源码 文件源码
def predictions_vs_actual_classification(model_results, model_name, n_bins, figsize=(7, 3)):
    """Plot binned predicted probabilities against the observed target rate.

    Holdout predictions are grouped into ``n_bins`` equal-width bins spanning
    [0, 1]. A bar chart on the left axis shows how many rows fall into each
    bin. A line on a twin right axis shows the mean of the target column
    within each bin, drawn together with the ``y = x`` diagonal for reference.

    Args:
        model_results: Object exposing ``holdout_data`` (a DataFrame with a
            ``'prediction'`` column) and ``target`` (the name of the target
            column in that DataFrame). Presumably the target is a 0/1 label,
            so its per-bin mean is an observed rate -- confirm with callers.
        model_name: Name shown in the figure title.
        n_bins: Number of equal-width bins over [0, 1]; must be at least 1.
        figsize: Figure size in inches, passed to matplotlib.

    Returns:
        The created matplotlib ``Figure``.

    Raises:
        ValueError: If ``n_bins`` is less than 1.
    """
    if n_bins < 1:
        raise ValueError('n_bins must be at least 1, got {0!r}'.format(n_bins))
    n_bins = int(n_bins)

    holdout = model_results.holdout_data
    target = model_results.target

    # linspace yields exactly n_bins + 1 edges ending at exactly 1.0. A float
    # arange step can stop short of 1.0 (e.g. n_bins=49) or overshoot past it.
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    bin_mids = (bins[:-1] + bins[1:]) / 2

    # pd.cut bins are right-closed, so include_lowest is needed to keep
    # predictions of exactly 0.0 in the first bin. labels=False returns
    # integer bin codes. Values outside [0, 1] get NaN, which groupby drops.
    bin_codes = pd.cut(holdout['prediction'], bins=bins, labels=False,
                       include_lowest=True)
    grouped = holdout[target].groupby(bin_codes)
    # Reindex so every bin is present (empty bins get count 0 and mean NaN).
    # This keeps the results aligned with bin_mids regardless of pandas'
    # categorical/observed groupby defaults.
    all_bins = range(n_bins)
    bin_counts = grouped.count().reindex(all_bins, fill_value=0)
    bin_means = grouped.mean().reindex(all_bins)

    fig, ax1 = plt.subplots(figsize=figsize)
    fig.suptitle('{0}: Predictions vs Actual'.format(model_name), fontsize=14)
    ax1.grid(False)
    ax1.bar(bin_mids, bin_counts, width=1 / n_bins, color=sns.light_palette('green')[1],
            label='row count', edgecolor='black')
    ax1.set_xlabel('predicted probability')
    ax1.set_ylabel('row count')

    ax2 = ax1.twinx()
    ax2.plot(bin_mids, bin_means, linewidth=3,
             marker='.', markersize=16, label='actual rate')
    ax2.plot(bins, bins, color=sns.color_palette()[2], label='main diagonal')
    ax2.set_ylabel('actual rate')

    # Merge both axes' entries into one legend. It is drawn on the twin axis,
    # which sits on top, so the line plots cannot hide it.
    handles, labels = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    legend = ax2.legend(handles + handles2, labels + labels2,
                        loc='best',
                        frameon=True,
                        framealpha=0.7)
    frame = legend.get_frame()
    frame.set_facecolor('white')
    return fig
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号