regularize.py source file

python

Project: DSI-personal-reference-kit  Author: teb311
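import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def fit_regression(X, y, regression_class=LinearRegression, regularization_const=.001):
    '''
        Fit a scikit-learn regression to a dataset (X, y) and cross-validate it.
        A regularization constant (lambda) is required if the regression class is Lasso or Ridge.

        X (pandas DataFrame): The data.
        y (pandas DataFrame or Series): The answers.
        regression_class (class): One of sklearn.linear_model.[LinearRegression, Ridge, Lasso]
        regularization_const: the regularization parameter (lambda) for Ridge or Lasso.
                              Called alpha by scikit-learn for interface reasons.

        Return:
            tuple, (the_fitted_predictor, mean of the per-fold cross-validated RMSE).
            For Ridge or Lasso the predictor is a Pipeline (scaler + regressor).
    '''
    if regression_class is LinearRegression:
        predictor = regression_class()
    else:
        # The original passed normalize=True, which was removed in scikit-learn 1.2.
        # Scaling in a pipeline replaces it. The scaling is not identical (standard
        # deviation rather than l2 norm), so the same alpha penalizes differently.
        predictor = make_pipeline(StandardScaler(), regression_class(alpha=regularization_const))

    predictor.fit(X, y)

    # cross_val_score clones the predictor, so every fold is fit from scratch.
    cross_scores = cross_val_score(predictor, X, y=y, scoring='neg_mean_squared_error')
    # Scores are negated MSE: flip the sign and take the root to get per-fold RMSE.
    cross_scores_corrected = np.sqrt(-1 * cross_scores)

    return (predictor, np.mean(cross_scores_corrected))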
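
The score the function returns is the mean of the per-fold RMSE values, which is not quite the same number as the root of the mean MSE. A small worked example may make the sign flip and the root easier to follow. The fold scores below are made-up values chosen for easy arithmetic, not output from any real model.

```python
import math

# Hypothetical per-fold scores in the form produced by
# scoring='neg_mean_squared_error' (negated MSE, so higher is better).
neg_mse_per_fold = [-4.0, -9.0, -16.0]

# Flip the sign and take the root of each fold, as fit_regression does.
rmse_per_fold = [math.sqrt(-s) for s in neg_mse_per_fold]
mean_of_rmse = sum(rmse_per_fold) / len(rmse_per_fold)

# Alternative aggregation: average the MSE first, then take a single root.
rmse_of_mean = math.sqrt(-sum(neg_mse_per_fold) / len(neg_mse_per_fold))

print(rmse_per_fold)            # [2.0, 3.0, 4.0]
print(mean_of_rmse)             # 3.0
print(round(rmse_of_mean, 3))   # 3.109
```

The mean of per-fold RMSE is never larger than the root of the mean MSE, so the two should not be compared against each other when tuning the regularization constant.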