def test_warm_start_oob():
    # Check that growing an ensemble via warm start yields the same
    # out-of-bag improvements as fitting in one go, for both the
    # regressor and the classifier.
    n_initial, n_total = 100, 200
    shared_params = dict(max_depth=1, subsample=0.5, random_state=1)
    X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)

    for Cls in (GradientBoostingRegressor, GradientBoostingClassifier):
        # Reference model: every stage fitted in a single call.
        reference = Cls(n_estimators=n_total, **shared_params)
        reference.fit(X, y)

        # Warm-started model: fit the initial stages, then raise
        # n_estimators and fit again on the same data.
        warm = Cls(n_estimators=n_initial, warm_start=True, **shared_params)
        warm.fit(X, y)
        warm.set_params(n_estimators=n_total)
        warm.fit(X, y)

        # Only the leading n_initial entries are compared -- presumably the
        # stages fitted before the warm-start continuation (TODO confirm
        # whether the later stages are expected to match as well).
        assert_array_almost_equal(
            warm.oob_improvement_[:n_initial],
            reference.oob_improvement_[:n_initial],
        )
评论列表
文章目录