gen_imgs.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pyGAM 作者: dswah 项目源码 文件源码
def wage_data_linear():
    X, y = wage()

    gam = LinearGAM(n_splines=10)
    gam.gridsearch(X, y, lam=np.logspace(-5,3,50))

    XX = generate_X_grid(gam)

    plt.figure()
    fig, axs = plt.subplots(1,3)

    titles = ['year', 'age', 'education']
    for i, ax in enumerate(axs):
        ax.plot(XX[:, i], gam.partial_dependence(XX, feature=i+1))
        ax.plot(XX[:, i], *gam.partial_dependence(XX, feature=i+1, width=.95)[1],
                c='r', ls='--')
        if i == 0:
            ax.set_ylim(-30,30);
        ax.set_title(titles[i])

    fig.tight_layout()
    plt.savefig('imgs/pygam_wage_data_linear.png', dpi=300)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号