test_gbrt.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:scikit-optimize 作者: scikit-optimize 项目源码 文件源码
def test_gbrt_base_estimator():
    rng = np.random.RandomState(1)
    N = 10000
    X = np.ones((N, 1))
    y = rng.normal(size=N)

    base = RandomForestRegressor()
    rgr = GradientBoostingQuantileRegressor(base_estimator=base)
    assert_raise_message(ValueError, 'type GradientBoostingRegressor',
                         rgr.fit, X, y)

    base = GradientBoostingRegressor()
    rgr = GradientBoostingQuantileRegressor(base_estimator=base)
    assert_raise_message(ValueError, 'quantile loss', rgr.fit, X, y)

    base = GradientBoostingRegressor(loss='quantile', n_estimators=20)
    rgr = GradientBoostingQuantileRegressor(base_estimator=base)
    rgr.fit(X, y)

    estimates = rgr.predict(X, return_quantiles=True)
    assert_almost_equal(stats.norm.ppf(rgr.quantiles),
                        np.mean(estimates, axis=0),
                        decimal=2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号