def plot_reliability_diagram(y, x, bins=np.linspace(0, 1, 21), size_points=True,
                             show_baseline=True, ax=None, marker='+', c='red',
                             **kwargs):
    """Draw a reliability (calibration) diagram on ``ax``.

    Predictions ``x`` are grouped with ``np.digitize(x, bins)``. For every
    group that contains at least one prediction, the mean prediction is
    plotted against the mean of the corresponding ``y`` values.

    Parameters
    ----------
    y : array-like
        Observed outcomes, same length as ``x``. Presumably 0/1 labels, so
        the per-bin mean is an observed frequency -- confirm with callers.
    x : array-like
        Predicted values (presumably probabilities in [0, 1]).
    bins : array-like
        Monotonically increasing bin edges passed to ``np.digitize``.
        Values below ``bins[0]`` or at/above ``bins[-1]`` form their own
        groups.
    size_points : bool
        If True, each marker's area is the number of samples in its bin.
    show_baseline : bool
        If True, draw the y = x diagonal as a black dashed line.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``_gca()``.
    marker, c
        Marker style and colour passed to ``ax.scatter``.
    **kwargs
        Extra keyword arguments forwarded to ``ax.scatter``.

    Returns
    -------
    tuple of three 1-D float arrays
        (mean prediction, mean outcome, sample count) for each non-empty
        bin, ordered by bin index.
    """
    if ax is None:
        ax = _gca()  # module-level helper defined elsewhere; presumably plt.gca()
    # Accept lists/tuples as well as arrays: the boolean-mask indexing
    # below requires NumPy arrays.
    y = np.asarray(y)
    x = np.asarray(x)

    bin_ids = np.digitize(x, bins)
    rows = []
    for bin_id in np.unique(bin_ids):
        in_bin = bin_ids == bin_id  # build the mask once per bin
        rows.append((np.mean(y[in_bin]), np.count_nonzero(in_bin),
                     np.mean(x[in_bin])))
    # reshape keeps a (0, 3) shape when there is no data, so the column
    # slicing below also works for empty input.
    stats = np.array(rows, dtype=float).reshape(-1, 3)
    observed, counts, predicted = stats[:, 0], stats[:, 1], stats[:, 2]

    if show_baseline:
        diagonal = np.linspace(0, 1, 100)
        ax.plot(diagonal, diagonal, 'k--')

    # A single scatter call for all bins does three things: it draws on
    # ``ax`` rather than the pyplot current axes, it applies marker/c in
    # both modes, and it creates one artist, so a ``label`` kwarg yields a
    # single legend entry.
    scatter_kwargs = {'marker': marker, 'c': c}
    if size_points:
        scatter_kwargs['s'] = counts
    scatter_kwargs.update(kwargs)
    ax.scatter(predicted, observed, **scatter_kwargs)
    ax.axis([-0.1, 1.1, -0.1, 1.1])
    return (predicted, observed, counts)
评论列表
文章目录