def wage_data_linear(path='imgs/pygam_wage_data_linear.png', dpi=300):
    """Fit a LinearGAM to the wage dataset and save its partial-dependence plots.

    Grid-searches the smoothing penalty ``lam`` over 50 log-spaced values
    in [1e-5, 1e3]. It then draws one panel per feature (year, age,
    education), each showing the partial-dependence curve and its 95%
    interval as red dashed lines.

    Args:
        path: Output image file. Defaults to the original hard-coded path.
        dpi: Resolution passed to ``savefig``. Defaults to the original 300.
    """
    X, y = wage()

    gam = LinearGAM(n_splines=10)
    gam.gridsearch(X, y, lam=np.logspace(-5, 3, 50))

    XX = generate_X_grid(gam)

    # plt.subplots creates the figure itself. A preceding plt.figure() would
    # only leave an orphan, empty figure open.
    fig, axs = plt.subplots(1, 3)
    titles = ['year', 'age', 'education']
    for i, (ax, title) in enumerate(zip(axs, titles)):
        # Column i of XX is plotted against feature=i+1. The +1 offset
        # presumably reflects feature 0 being the intercept term in the
        # pyGAM version this targets; TODO confirm against the installed
        # pyGAM.
        # A single call with `width` returns both the dependence curve and
        # the interval, so the model is not evaluated twice.
        pdep, confi = gam.partial_dependence(XX, feature=i + 1, width=.95)
        ax.plot(XX[:, i], pdep)
        ax.plot(XX[:, i], *confi, c='r', ls='--')
        if i == 0:
            # Fixed y-range for the first ('year') panel only, as in the
            # original; presumably this keeps its interval readable.
            ax.set_ylim(-30, 30)
        ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path, dpi=dpi)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
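# 评论列表 ("Comment list") - leftover web-page navigation text, not code
# 文章目录 ("Table of contents") - leftover web-page navigation text, not code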