def test_sparse_classification():
    # Check classification for various parameter settings on sparse input:
    # a BaggingClassifier trained and queried on sparse matrices must give
    # exactly the same outputs as one trained and queried on the dense data.

    class CustomSVC(SVC):
        """SVC variant that records the nature of the training set"""

        def fit(self, X, y):
            super(CustomSVC, self).fit(X, y)
            # Remember the container type this base estimator was fitted on
            # so the test can inspect what the ensemble handed down to it.
            self.data_type_ = type(X)
            return self

    rng = check_random_state(0)
    # The held-out labels are not needed: only the outputs of the two
    # classifiers are compared with each other.
    X_train, X_test, y_train, _ = train_test_split(iris.data,
                                                   iris.target,
                                                   random_state=rng)
    parameter_sets = [
        {"max_samples": 0.5,
         "max_features": 2,
         "bootstrap": True,
         "bootstrap_features": True},
        {"max_samples": 1.0,
         "max_features": 4,
         "bootstrap": True,
         "bootstrap_features": True},
        {"max_features": 2,
         "bootstrap": False,
         "bootstrap_features": True},
        {"max_samples": 0.5,
         "bootstrap": True,
         "bootstrap_features": False},
    ]
    methods = ['predict', 'predict_proba', 'predict_log_proba',
               'decision_function']

    for sparse_format in [csc_matrix, csr_matrix]:
        X_train_sparse = sparse_format(X_train)
        X_test_sparse = sparse_format(X_test)
        sparse_type = type(X_train_sparse)

        for params in parameter_sets:
            # Fitting depends only on the parameter set and the fixed seed,
            # not on which prediction method is checked afterwards, so each
            # classifier is fitted once and reused for every method below.

            # Trained on sparse format
            sparse_classifier = BaggingClassifier(
                base_estimator=CustomSVC(decision_function_shape='ovr'),
                random_state=1,
                **params
            ).fit(X_train_sparse, y_train)

            # Trained on dense format
            dense_classifier = BaggingClassifier(
                base_estimator=CustomSVC(decision_function_shape='ovr'),
                random_state=1,
                **params
            ).fit(X_train, y_train)

            for f in methods:
                sparse_results = getattr(sparse_classifier, f)(X_test_sparse)
                dense_results = getattr(dense_classifier, f)(X_test)
                assert_array_equal(sparse_results, dense_results)

            # Every base estimator of the sparse-trained ensemble must have
            # been fitted on data of the same sparse type as the input.
            types = [i.data_type_ for i in sparse_classifier.estimators_]
            assert all(t == sparse_type for t in types)
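# NOTE: the two lines below are web-page navigation residue left over from
# extraction ("评论列表" = comment list, "文章目录" = article contents); they are
# not part of the test module.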