test_multioutput.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_multi_target_sample_weights():
    # weighted regressor
    Xw = [[1,2,3], [4,5,6]]
    yw = [[3.141, 2.718], [2.718, 3.141]]
    w = [2., 1.]
    rgr_w = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
    rgr_w.fit(Xw, yw, w)

    # unweighted, but with repeated samples
    X = [[1,2,3], [1,2,3], [4,5,6]]
    y = [[3.141, 2.718], [3.141, 2.718], [2.718, 3.141]]
    rgr = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
    rgr.fit(X, y)

    X_test = [[1.5,2.5,3.5], [3.5,4.5,5.5]]
    assert_almost_equal(rgr.predict(X_test), rgr_w.predict(X_test))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号