test_model_selection.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:dask-searchcv 作者: dask 项目源码 文件源码
def test_pipeline_feature_union():
    """Compare ``GridSearchCV`` and ``dcv.GridSearchCV`` on a nested pipeline.

    The pipeline under test mixes real estimators with ``None`` placeholders:
    a pipeline and a feature union whose steps are all ``None``, a ``None``
    pipeline step, and a ``None`` entry inside the weighted feature union.
    Both searches are fitted on the iris data with the same parameter grid
    (the ``dcv`` one with ``scheduler='sync'``), and the test asserts that
    they agree on the best parameters and on the fitted PCA components,
    SelectKBest scores and SVC coefficients of their best estimators.
    """
    dataset = load_iris()
    features, labels = dataset.data, dataset.target

    # Containers made up solely of ``None`` steps.
    noop_pipeline = Pipeline([('first', None), ('second', None)])
    noop_union = FeatureUnion([('first', None), ('second', None)])

    # Order matters: the assertions below index into this list by position
    # (0 -> 'pca', 2 -> 'kbest').
    union_members = [
        ('pca', PCA(random_state=0)),
        ('missing', None),
        ('kbest', SelectKBest()),
        ('empty_union', noop_union),
    ]
    feature_union = FeatureUnion(union_members,
                                 transformer_weights={'pca': 0.5})

    steps = [
        ('empty_pipeline', noop_pipeline),
        ('scaling', Pipeline([('transform', ScalingTransformer())])),
        ('missing', None),
        ('union', feature_union),
        ('svc', SVC(kernel='linear', random_state=0)),
    ]
    pipe = Pipeline(steps)

    param_grid = {
        'scaling__transform__factor': [1, 2],
        'union__pca__n_components': [1, 2, 3],
        'union__kbest__k': [1, 2],
        'svc__C': [0.1, 1, 10],
    }

    reference = GridSearchCV(pipe, param_grid=param_grid)
    reference.fit(features, labels)
    candidate = dcv.GridSearchCV(pipe, param_grid=param_grid,
                                 scheduler='sync')
    candidate.fit(features, labels)

    # Both searches must select the same parameter combination.
    assert reference.best_params_ == candidate.best_params_

    ref_steps = reference.best_estimator_.named_steps
    cand_steps = candidate.best_estimator_.named_steps
    ref_union = ref_steps['union'].transformer_list
    cand_union = cand_steps['union'].transformer_list

    # Fitted PCA (first union entry) must have matching components.
    np.testing.assert_allclose(ref_union[0][1].components_,
                               cand_union[0][1].components_)

    # Fitted SelectKBest (third union entry) must have matching scores.
    np.testing.assert_allclose(ref_union[2][1].scores_,
                               cand_union[2][1].scores_)

    # Final SVC step must have matching coefficients.
    np.testing.assert_allclose(ref_steps['svc'].coef_,
                               cand_steps['svc'].coef_)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号