3.Lasso regression.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ML-note 作者: JasonK93 项目源码 文件源码
def test_Lasso_alpha(*data):
    '''
    test the score with different alpha
    :param data: train_data, test_data, train_value, test_value
    :return: None
    '''

    X_train,X_test,y_train,y_test=data
    alphas=[0.01,0.02,0.05,0.1,0.2,0.5,1,2,5,10,20,50,100,200,500,1000]
    scores=[]
    for i,alpha in enumerate(alphas):
        regr = linear_model.Lasso(alpha=alpha)
        regr.fit(X_train, y_train)
        scores.append(regr.score(X_test, y_test))
    ## graph
    fig=plt.figure()
    ax=fig.add_subplot(1,1,1)
    ax.plot(alphas,scores)
    ax.set_xlabel(r"$\alpha$")
    ax.set_ylabel(r"score")
    ax.set_xscale('log')
    ax.set_title("Lasso")
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号