def test_oob_score_classification():
    """Check that the OOB score is a good estimate of generalization error.

    For each base estimator, a bagging classifier is fitted on a training
    split of the iris data with ``oob_score=True``. Its out-of-bag score must
    lie within 0.1 of the accuracy measured on the held-out test split.

    The test also checks that fitting with a single estimator emits a
    ``UserWarning``.
    """
    # One shared random state drives the split and every ensemble, so the
    # order of the calls below determines the random stream each one sees.
    rng = check_random_state(0)
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data,
        iris.target,
        random_state=rng,
    )

    # NOTE(review): `base_estimator`, `assert_less` and `assert_warns` look
    # like older scikit-learn APIs -- confirm against the targeted version.
    for base_estimator in [DecisionTreeClassifier(), SVC()]:
        clf = BaggingClassifier(
            base_estimator=base_estimator,
            n_estimators=100,
            bootstrap=True,
            oob_score=True,
            random_state=rng,
        ).fit(X_train, y_train)

        test_score = clf.score(X_test, y_test)

        # The OOB estimate should track the held-out accuracy closely.
        assert_less(abs(test_score - clf.oob_score_), 0.1)

        # Test with few estimators: `fit` is passed unbound so that
        # `assert_warns` performs the call and captures the warning itself.
        assert_warns(
            UserWarning,
            BaggingClassifier(
                base_estimator=base_estimator,
                n_estimators=1,
                bootstrap=True,
                oob_score=True,
                random_state=rng,
            ).fit,
            X_train,
            y_train,
        )
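# Comment list / Article table of contents
# (apparently navigation text from the page this snippet was copied from)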