segmentation.py source code

python

Project: crankshaft  Author: CartoDB
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split


def train_model(target, features, model_params, test_split):
    """
        Train the Gradient Boosting model on the provided data and calculate the accuracy of the model
        Input:
            @param target: 1D array of the variable that the model is to be trained to predict
            @param features: 2D array (NSamples * NFeatures) to use in training the model
            @param model_params: A dictionary of model parameters; the full specification can be found on the
                scikit-learn page for [GradientBoostingRegressor](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html)
            @param test_split: The fraction of the data to be withheld for testing the model / calculating the accuracy
    """
    # Hold out test_split of the rows so accuracy is measured on data the model never saw
    features_train, features_test, target_train, target_test = train_test_split(features, target, test_size=test_split)
    model = GradientBoostingRegressor(**model_params)
    model.fit(features_train, target_train)
    # calculate_model_accuracy is defined elsewhere in this module (not shown here).
    # Score on the held-out rows, not on the full data set that includes the training rows.
    accuracy = calculate_model_accuracy(model, features_test, target_test)
    return model, accuracy
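To make the hold-out flow in train_model concrete without depending on scikit-learn, here is a minimal standard-library sketch of the same three steps: split off a test fraction, fit on the remaining rows, and score only on the held-out rows. The names hold_out_split, mean_squared_error, MeanBaseline and train_baseline are hypothetical and introduced only for illustration. Two further assumptions: the test set size is taken as ceil(test_split * n), meant to mirror a float test_size, and mean squared error stands in for the accuracy metric, since calculate_model_accuracy is not shown in this excerpt. A trivial "predict the training mean" model replaces GradientBoostingRegressor.

```python
import math
import random
from typing import List, Sequence, Tuple

Row = Sequence[float]


def hold_out_split(features: Sequence[Row], target: Sequence[float],
                   test_split: float, seed: int = 0
                   ) -> Tuple[List[Row], List[Row], List[float], List[float]]:
    """Hypothetical helper: shuffle rows and hold out ceil(test_split * n) for testing."""
    if len(features) != len(target):
        raise ValueError("features and target must have the same number of rows")
    if not 0.0 < test_split < 1.0:
        raise ValueError("test_split must be a fraction strictly between 0 and 1")
    n_samples = len(target)
    n_test = math.ceil(test_split * n_samples)  # assumed sizing rule
    if n_test >= n_samples:
        raise ValueError("test_split leaves no rows for training")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return ([features[i] for i in train_idx], [features[i] for i in test_idx],
            [target[i] for i in train_idx], [target[i] for i in test_idx])


def mean_squared_error(target: Sequence[float], predictions: Sequence[float]) -> float:
    """Assumed accuracy metric: average squared difference (lower is better)."""
    if len(target) != len(predictions) or not target:
        raise ValueError("need equal-length, non-empty sequences")
    return sum((t - p) ** 2 for t, p in zip(target, predictions)) / len(target)


class MeanBaseline:
    """Toy regressor standing in for GradientBoostingRegressor: predicts the training mean."""

    def fit(self, features: Sequence[Row], target: Sequence[float]) -> "MeanBaseline":
        self.mean_ = sum(target) / len(target)
        return self

    def predict(self, features: Sequence[Row]) -> List[float]:
        return [self.mean_ for _ in features]


def train_baseline(target: Sequence[float], features: Sequence[Row],
                   test_split: float, seed: int = 0) -> Tuple[MeanBaseline, float]:
    """Same shape as train_model: split, fit on training rows, score on held-out rows."""
    f_train, f_test, t_train, t_test = hold_out_split(features, target, test_split, seed)
    model = MeanBaseline().fit(f_train, t_train)
    return model, mean_squared_error(t_test, model.predict(f_test))


# Usage: 10 rows with target = 2 * feature; ceil(0.25 * 10) = 3 rows are held out.
features = [[float(i)] for i in range(10)]
target = [2.0 * i for i in range(10)]
model, mse = train_baseline(target, features, test_split=0.25)
print(f"held-out MSE: {mse:.3f}")
```

The key design point carried over from train_model is that the score is computed only on rows excluded from fit, so it estimates performance on unseen data rather than how well the model memorised its training set.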