test_base.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_score_sample_weight():
    """Check that ``score`` takes ``sample_weight`` into account.

    One classifier and one regressor are each fitted on a bundled dataset
    and then scored twice on the training data: once unweighted and once
    with random integer sample weights.  The two scores must differ;
    identical scores would suggest the weights are being ignored.
    """
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.tree import DecisionTreeRegressor
    from sklearn import datasets

    # Fixed seed so the generated weights are reproducible between runs.
    rng = np.random.RandomState(0)

    # test both ClassifierMixin and RegressorMixin
    estimators = [DecisionTreeClassifier(max_depth=2),
                  DecisionTreeRegressor(max_depth=2)]
    # NOTE: the regression fixture used to be ``load_boston``, which was
    # deprecated in scikit-learn 1.0 and removed in 1.2.  ``load_diabetes``
    # is available in old and new releases alike.
    sets = [datasets.load_iris(),
            datasets.load_diabetes()]

    for est, ds in zip(estimators, sets):
        est.fit(ds.data, ds.target)
        # generate random sample weights (integers in [1, 10), one per sample)
        sample_weight = rng.randint(1, 10, size=len(ds.target))
        # check that the score with and without sample weights are different
        assert_not_equal(est.score(ds.data, ds.target),
                         est.score(ds.data, ds.target,
                                   sample_weight=sample_weight),
                         msg="Unweighted and weighted scores "
                             "are unexpectedly equal")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号