plots.py source code

python

Project: guacml  Author: guacml
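import matplotlib.pyplot as plt
import seaborn as sns

from guacml import utils  # assumed import path for guacml's remove_outlier_rows helper


def predictions_vs_actual_regression(model_results, model_name, size=6, bins=None,
                                     gridsize=30, outlier_ratio=None, **kwargs):
    """Hexbin joint plot of holdout predictions against the actual target values."""
    holdout = model_results.holdout_data
    target = model_results.target

    # Optionally trim extreme rows in both the predictions and the target
    if outlier_ratio is not None:
        holdout = utils.remove_outlier_rows(holdout, 'prediction', outlier_ratio)
        holdout = utils.remove_outlier_rows(holdout, target, outlier_ratio)

    sns.set(style="white", color_codes=True)

    # Black edges on the marginal histogram bars (passed straight to sns.histplot)
    marginal_kws = dict(edgecolor='black')
    # Keyword-only API of seaborn >= 0.11: kind='hex', and `size` was renamed `height`
    grid = sns.jointplot(x='prediction', y=target, data=holdout, kind='hex',
                         height=size, space=0, marginal_kws=marginal_kws,
                         gridsize=gridsize, bins=bins, **kwargs)
    # Title the figure jointplot created, not whichever figure was current before
    grid.fig.suptitle('{0}: Predictions vs Actual'.format(model_name), fontsize=14)
    # Shrink the axes so the colorbar is visible on the right
    grid.fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
    cax = grid.fig.add_axes([.95, .18, .04, .5])  # x, y, width, height
    # sns.plt no longer exists; attach the colorbar to the hexbin collection explicitly
    color_bar = grid.fig.colorbar(grid.ax_joint.collections[0], cax=cax)

    if bins is None:
        color_bar.set_label('count')
    elif bins == 'log':
        # Recent matplotlib applies a LogNorm for bins='log': raw counts on a log scale
        color_bar.set_label('count (log scale)')
    return grid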
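The snippet relies on `utils.remove_outlier_rows`, whose body is not shown. The following is a hypothetical stand-in, assuming `outlier_ratio` is the total fraction of rows to drop, split evenly between the lower and upper tails of one column; guacml's real helper may define the ratio differently.

```python
import pandas as pd


def remove_outlier_rows(df: pd.DataFrame, column: str, outlier_ratio: float) -> pd.DataFrame:
    """Hypothetical stand-in for guacml's utils.remove_outlier_rows.

    Assumption: outlier_ratio is the total fraction of rows to trim,
    half from the low tail of `column` and half from the high tail.
    """
    if not 0 <= outlier_ratio < 1:
        raise ValueError("outlier_ratio must be in [0, 1)")
    lower = df[column].quantile(outlier_ratio / 2)
    upper = df[column].quantile(1 - outlier_ratio / 2)
    # Keep only rows whose value lies inside the inclusive [lower, upper] band
    return df[df[column].between(lower, upper)]


holdout = pd.DataFrame({"prediction": range(100), "y": range(100)})
trimmed = remove_outlier_rows(holdout, "prediction", 0.1)
```

Calling it once per column, as the plotting function does, trims on predictions first and then on the target, so the two cuts compound rather than overlap exactly.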
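The colorbar trick in the function (shrink the subplots, then give the colorbar its own manually placed axes) does not depend on seaborn. A minimal matplotlib-only sketch with synthetic data; the function name `hexbin_with_side_colorbar` and the data are illustrative assumptions, not part of guacml.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def hexbin_with_side_colorbar(x, y, gridsize=30):
    """Hypothetical helper: hexbin plot with a colorbar on its own axes."""
    fig, ax = plt.subplots(figsize=(6, 6))
    hb = ax.hexbin(x, y, gridsize=gridsize)
    # Leave room on the right edge of the figure for the colorbar
    fig.subplots_adjust(left=0.1, right=0.85, top=0.9, bottom=0.1)
    cax = fig.add_axes([0.88, 0.18, 0.04, 0.5])  # x, y, width, height in figure coords
    # Passing the hexbin collection explicitly avoids relying on the "current image"
    color_bar = fig.colorbar(hb, cax=cax)
    color_bar.set_label('count')
    return fig, hb, color_bar


rng = np.random.default_rng(0)
actual = rng.normal(size=1000)
prediction = actual + rng.normal(scale=0.3, size=1000)
fig, hb, color_bar = hexbin_with_side_colorbar(prediction, actual)
```

Handing the mappable to `colorbar` directly is the same choice made in the corrected function above via `grid.ax_joint.collections[0]`.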