test_model_selection.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:dask-searchcv 作者: dask 项目源码 文件源码
def test_feature_union(weights):
    """Grid-search a pipeline whose first step is a ``FeatureUnion``.

    For each combination of per-transformer constants and transformer
    weights, a reference union is configured directly and used to transform
    ``X``.  The first row of that result is passed to ``CheckXClassifier``
    (presumably so the classifier can verify the data it receives during the
    search -- confirm against its definition), and the same parameters are
    added, prefixed with ``union__``, as a single-point grid for
    ``dcv.GridSearchCV``.

    Parameters
    ----------
    weights : iterable or None
        Candidate values for ``transformer_weights``.  A falsy value means
        only ``None`` (no weighting) is tried.
    """
    def make_union():
        # Three separately configurable transformers named tr0, tr1, tr2.
        return FeatureUnion([('tr%d' % i, ScalingTransformer())
                             for i in range(3)])

    X = np.ones((10, 5))
    y = np.zeros(10)

    union = make_union()

    # Each tuple holds the constant for (tr0, tr1, tr2); ``None`` disables
    # the corresponding transformer.
    factors = [(2, 3, 5), (2, 4, 5), (2, 4, 6),
               (2, 4, None), (None, None, None)]
    grid = []
    for constants, w in product(factors, weights or [None]):
        p = {}
        for n, c in enumerate(constants):
            name = 'tr%d' % n
            if c is None:
                p[name] = None
            elif n == 2:
                # The third transformer (index 2) is always replaced by a
                # whole estimator instead of having its ``factor`` set, so
                # that estimator-valued parameters are exercised as well.
                p[name] = ScalingTransformer(c)
            else:
                p[name + '__factor'] = c

        # Expected transform output for this parameter combination.
        sol = union.set_params(transformer_weights=w, **p).transform(X)

        if w is not None:
            p['transformer_weights'] = w
        # One single-point grid per combination, addressed through the
        # pipeline step name ``union``.
        p2 = {'union__' + k: [v] for k, v in p.items()}
        p2['est'] = [CheckXClassifier(sol[0])]
        grid.append(p2)

    # Need to recreate the union after setting estimators to `None` above
    union = make_union()

    pipe = Pipeline([('union', union), ('est', CheckXClassifier())])
    gs = dcv.GridSearchCV(pipe, grid, refit=False, cv=2)

    # record=True captures warnings instead of displaying them; presumably
    # the disabled (``None``) transformers trigger warnings during fitting.
    with warnings.catch_warnings(record=True):
        gs.fit(X, y)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号