test_bagging.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目: Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_sparse_classification():
    """Check that BaggingClassifier behaves the same on sparse and dense input.

    For each sparse format (CSC and CSR) and each bagging parameter set, two
    ensembles are fitted with the same ``random_state``: one on the sparse
    training matrix and one on the dense training array. Every prediction
    method (``predict``, ``predict_proba``, ``predict_log_proba``,
    ``decision_function``) must return exactly equal arrays for the sparse
    and dense test data. Each base estimator of the sparse ensemble must also
    have been fitted on the same sparse matrix type the ensemble received.
    """

    class CustomSVC(SVC):
        """SVC variant that records the type of the training set it was fit on."""

        def fit(self, X, y):
            super(CustomSVC, self).fit(X, y)
            # Remember the container type so the test can check which
            # matrix type each base estimator actually received.
            self.data_type_ = type(X)
            return self

    rng = check_random_state(0)
    X_train, X_test, y_train, _ = train_test_split(iris.data,
                                                   iris.target,
                                                   random_state=rng)
    parameter_sets = [
        {"max_samples": 0.5,
         "max_features": 2,
         "bootstrap": True,
         "bootstrap_features": True},
        {"max_samples": 1.0,
         "max_features": 4,
         "bootstrap": True,
         "bootstrap_features": True},
        {"max_features": 2,
         "bootstrap": False,
         "bootstrap_features": True},
        {"max_samples": 0.5,
         "bootstrap": True,
         "bootstrap_features": False},
    ]
    prediction_methods = ['predict', 'predict_proba', 'predict_log_proba',
                          'decision_function']

    for sparse_format in [csc_matrix, csr_matrix]:
        X_train_sparse = sparse_format(X_train)
        X_test_sparse = sparse_format(X_test)
        sparse_type = type(X_train_sparse)
        for params in parameter_sets:
            # The fitted ensembles do not depend on which prediction method
            # is being compared (random_state is fixed), so fit each one once
            # per parameter set instead of once per method.
            sparse_classifier = BaggingClassifier(
                base_estimator=CustomSVC(decision_function_shape='ovr'),
                random_state=1,
                **params
            ).fit(X_train_sparse, y_train)
            dense_classifier = BaggingClassifier(
                base_estimator=CustomSVC(decision_function_shape='ovr'),
                random_state=1,
                **params
            ).fit(X_train, y_train)

            for method in prediction_methods:
                sparse_results = getattr(sparse_classifier, method)(X_test_sparse)
                dense_results = getattr(dense_classifier, method)(X_test)
                assert_array_equal(sparse_results, dense_results)

            # Every base estimator must have been trained on the same sparse
            # matrix type that was passed to the ensemble.
            assert all(estimator.data_type_ is sparse_type
                       for estimator in sparse_classifier.estimators_)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号