test_optimisation.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:uncover-ml 作者: GeoscienceAustralia 项目源码 文件源码
def test_svr_pipeline(get_transform, get_svr_kernel):
    trans = get_transform()
    pipe = Pipeline(steps=[('svr', svr())])
    param_dict = {'svr__kernel': [get_svr_kernel]}
    param_dict['svr__target_transform'] = [trans]

    estimator = GridSearchCV(pipe,
                             param_dict,
                             n_jobs=1,
                             iid=False,
                             pre_dispatch=2,
                             verbose=True,
                             )
    np.random.seed(1)
    estimator.fit(X=1 + np.random.rand(10, 5), y=1. + np.random.rand(10))
    assert estimator.cv_results_['mean_train_score'][0] > -10.0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号