def test_base_estimator():
    """Check ``base_estimator_`` for default and explicit base estimators.

    For both ``BaggingClassifier`` and ``BaggingRegressor``, an ensemble is
    fitted with ``None`` and with two explicit estimators as the first
    argument. After fitting, ``base_estimator_`` must be an instance of the
    expected class; ``None`` is expected to resolve to a decision tree
    (``DecisionTreeClassifier`` / ``DecisionTreeRegressor`` respectively).
    """
    # The same rng object is passed to both splits, classification first.
    rng = check_random_state(0)

    # Classification
    X_train, _, y_train, _ = train_test_split(
        iris.data, iris.target, random_state=rng
    )
    classifier_cases = [
        (None, DecisionTreeClassifier),  # default base estimator
        (DecisionTreeClassifier(), DecisionTreeClassifier),
        (Perceptron(), Perceptron),
    ]
    for base_estimator, expected_type in classifier_cases:
        ensemble = BaggingClassifier(
            base_estimator, n_jobs=3, random_state=0
        ).fit(X_train, y_train)
        assert_true(isinstance(ensemble.base_estimator_, expected_type))

    # Regression
    X_train, _, y_train, _ = train_test_split(
        boston.data, boston.target, random_state=rng
    )
    regressor_cases = [
        (None, DecisionTreeRegressor),  # default base estimator
        (DecisionTreeRegressor(), DecisionTreeRegressor),
        (SVR(), SVR),
    ]
    for base_estimator, expected_type in regressor_cases:
        ensemble = BaggingRegressor(
            base_estimator, n_jobs=3, random_state=0
        ).fit(X_train, y_train)
        assert_true(isinstance(ensemble.base_estimator_, expected_type))
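# Page-navigation residue from the source web page (not part of the test):
# "Comment list" / "Article table of contents".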