def test_gbrt_base_estimator():
    """Check how GradientBoostingQuantileRegressor validates ``base_estimator``.

    Three cases are exercised:

    1. A base estimator that is not a ``GradientBoostingRegressor`` is
       rejected with a ``ValueError`` when ``fit`` is called.
    2. A ``GradientBoostingRegressor`` left at its default (non-quantile)
       loss is rejected with a ``ValueError`` when ``fit`` is called.
    3. A ``GradientBoostingRegressor`` with ``loss='quantile'`` is accepted,
       and the predicted quantiles match the quantiles of the standard
       normal distribution that the targets were drawn from.
    """
    rng = np.random.RandomState(1)
    N = 10000
    # A constant feature gives the trees nothing to split on, so every sample
    # receives the same prediction: the model can only recover the marginal
    # quantiles of ``y``.
    X = np.ones((N, 1))
    y = rng.normal(size=N)

    # Wrong estimator type.
    base = RandomForestRegressor()
    rgr = GradientBoostingQuantileRegressor(base_estimator=base)
    assert_raise_message(ValueError, 'type GradientBoostingRegressor',
                         rgr.fit, X, y)

    # Right estimator type, but not configured with the quantile loss.
    base = GradientBoostingRegressor()
    rgr = GradientBoostingQuantileRegressor(base_estimator=base)
    assert_raise_message(ValueError, 'quantile loss', rgr.fit, X, y)

    # Valid base estimator; n_estimators is kept small so the test stays fast.
    base = GradientBoostingRegressor(loss='quantile', n_estimators=20)
    rgr = GradientBoostingQuantileRegressor(base_estimator=base)
    rgr.fit(X, y)

    # Presumably shaped (n_samples, len(rgr.quantiles)) -- TODO confirm.
    # Averaging over samples gives one estimate per requested quantile, which
    # should match the N(0, 1) quantiles since y ~ N(0, 1).
    estimates = rgr.predict(X, return_quantiles=True)
    assert_almost_equal(stats.norm.ppf(rgr.quantiles),
                        np.mean(estimates, axis=0),
                        decimal=2)
# Web-page navigation residue, not test code:
# "Comment list" / "Table of contents"