regressor.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:EarlyWarning 作者: wjlei1990 项目源码 文件源码
def train_lasso_model(_train_x, train_y, _predict_x):
    """Fit a Lasso regressor whose alpha is selected by 5-fold LassoCV.

    Both feature matrices are first passed through ``standarize_feature``
    (defined elsewhere in this file; presumably it scales the prediction set
    with statistics from the training set -- confirm). ``LassoCV`` then picks
    the regularisation strength. Its alpha grid, the mean CV error for each
    alpha, and the error at the selected alpha are printed. A plain ``Lasso``
    is then refit on the whole training set with that alpha.

    Args:
        _train_x: raw training features (before standardisation).
        train_y: training targets.
        _predict_x: raw features of the samples to predict.

    Returns:
        dict with keys:
            "y": predictions for ``_predict_x``,
            "train_y": in-sample predictions for ``_train_x``,
            "coef": coefficients of the refit ``Lasso`` model.
    """
    print_title("Lasso Regressor")

    x_fit, x_new = standarize_feature(_train_x, _predict_x)

    # Cross-validated search along the Lasso regularisation path.
    cv_model = linear_model.LassoCV(precompute=True, cv=5, verbose=1,
                                    n_jobs=4)
    cv_model.fit(x_fit, train_y)
    print("alphas: %s" % cv_model.alphas_)
    # mse_path_ is averaged over axis 1 (one mean value per alpha).
    print("mse path: %s" % np.mean(cv_model.mse_path_, axis=1))

    best_alpha = cv_model.alpha_
    # np.where yields a tuple of index arrays. The tuple itself is handed to
    # %-formatting, so its single element is what gets printed.
    hits = np.where(cv_model.alphas_ == best_alpha)
    print("itemindex: %s" % hits)
    best_mse = cv_model.mse_path_[hits[0], :].mean()
    print("Best alpha using bulit-in LassoCV: %f(mse: %f)"
          % (best_alpha, best_mse))

    # Refit a plain Lasso on the full training set at the selected alpha.
    final_model = linear_model.Lasso(alpha=best_alpha)
    final_model.fit(x_fit, train_y)
    print("Non-zeros coef: %d" % np.count_nonzero(final_model.coef_))

    return {
        "y": final_model.predict(x_new),
        "train_y": final_model.predict(x_fit),
        "coef": final_model.coef_,
    }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号