calibration_utils.py source code

python

Project: introspective  Author: numeristical
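import numpy as np
import matplotlib.pyplot as plt


def plot_reliability_diagram(y, x, bins=np.linspace(0, 1, 21), size_points=True,
                             show_baseline=True, ax=None, marker='+', c='red', **kwargs):
    # y: binary outcomes (0/1); x: predicted probabilities for the positive class
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if ax is None:
        ax = plt.gca()
    # Assign each prediction to a probability bin
    digitized_x = np.digitize(x, bins)
    # One row per occupied bin: [observed event rate, number of points, mean prediction]
    mean_count_array = np.array([[np.mean(y[digitized_x == i]),
                                  np.sum(digitized_x == i),
                                  np.mean(x[digitized_x == i])]
                                 for i in np.unique(digitized_x)])
    if show_baseline:
        # Perfect calibration: observed rate equals predicted probability
        ax.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100), 'k--')
    # Optionally scale marker area by the number of points in each bin
    sizes = mean_count_array[:, 1] if size_points else None
    ax.scatter(mean_count_array[:, 2], mean_count_array[:, 0], s=sizes,
               marker=marker, c=c, **kwargs)
    ax.axis([-0.1, 1.1, -0.1, 1.1])
    # Return (mean prediction, observed rate, count) per bin
    return mean_count_array[:, 2], mean_count_array[:, 0], mean_count_array[:, 1]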
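The per-bin statistics that drive the plot can also be computed on their own, which is handy for checking calibration numerically or for testing without a display. The sketch below is a hypothetical helper, `reliability_table`, that is not part of the introspective project; it assumes `y` holds 0/1 outcomes and `x` holds predicted probabilities, and it swaps the per-bin list comprehension for `np.unique(..., return_inverse=True)` plus weighted `np.bincount`, so each statistic is computed in a single pass.

```python
import numpy as np


def reliability_table(y, x, bins=np.linspace(0, 1, 21)):
    """Hypothetical helper: per-bin (mean prediction, observed rate, count).

    Uses the same binning as plot_reliability_diagram: np.digitize assigns
    each prediction to a bin, and only occupied bins are reported, in
    ascending bin order.
    """
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    idx = np.digitize(x, bins)
    # Relabel occupied bins as 0..k-1 (sorted, matching np.unique order)
    _, inverse = np.unique(idx, return_inverse=True)
    inverse = inverse.ravel()
    counts = np.bincount(inverse)
    # Weighted bincount gives per-bin sums; divide by counts for means
    mean_pred = np.bincount(inverse, weights=x) / counts
    obs_rate = np.bincount(inverse, weights=y) / counts
    return mean_pred, obs_rate, counts


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.uniform(0, 1, 10_000)
    # Outcomes drawn so that P(y = 1 | x) = x, i.e. perfectly calibrated
    y = (rng.uniform(0, 1, x.size) < x).astype(int)
    mean_pred, obs_rate, counts = reliability_table(y, x)
    for p, r, n in zip(mean_pred, obs_rate, counts):
        print(f"{p:.3f}  {r:.3f}  {n}")
```

Because the synthetic outcomes are drawn with P(y = 1 | x) = x, the observed rate in each bin should sit close to the mean prediction, which is exactly what the dashed diagonal in the reliability diagram represents. Passing the same `y` and `x` to `plot_reliability_diagram` draws these rows as points.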