test_base.py — file source code

python
Views: 25 · Favorites: 0 · Likes: 0 · Comments: 0

Project: l1l2py · Author: slipguru · Project source · File source
def test_ridge():
    """Test Ridge regression for different values of mu.

    The training data satisfy y = x1 + x2 + 1 exactly, so the
    unregularized fit (mu=0) must recover unit coefficients and an
    intercept of 1.  For mu > 0 the test checks the shrunk coefficients,
    the matching intercept, and the predictions on held-out points T.
    """
    # A simple sum function (with intercept): y = x1 + x2 + 1
    X = [[1, 2], [3, 4], [5, 6]]
    y = [sum(x) + 1 for x in X]
    T = [[7, 8], [9, 10], [2, 1]]

    model = RidgeRegression(mu=0.0).fit(X, y)  # OLS
    assert_array_almost_equal([1, 1], model.coef_)
    assert_array_almost_equal([16, 20, 4], model.predict(T))
    assert_almost_equal(1.0, model.intercept_)

    # Equivalence with standard least squares on centred data.  Centring X
    # removes the intercept term.  y does not need centring because the
    # columns of Xc sum to zero.  Here Xc == [[-2, -2], [0, 0], [2, 2]] has
    # rank 1 (identical columns), so lstsq returns its minimum-norm
    # solution, which is [1, 1].
    # NOTE(review): `la` is presumably numpy.linalg or scipy.linalg;
    # both return the minimum-norm solution for rank-deficient systems.
    Xc = np.asarray(X, dtype=float) - np.mean(X, axis=0)
    assert_array_almost_equal(la.lstsq(Xc, y)[0], model.coef_)

    # Regularized fits, each row holding:
    # (mu, expected coef_, expected intercept_, expected predict(T)).
    # NOTE(review): these numbers are consistent with
    #   coef      = 16 / (16 + 3 * mu) for each component, and
    #   intercept = mean(y) - mean(X, axis=0) @ coef = 8 - 7 * coef,
    # i.e. a penalty scaled by the number of samples (3).  Confirm this
    # against RidgeRegression's objective.
    cases = [
        (0.5, [0.91428571, 0.91428571], 1.6,
         [15.31428571, 18.97142857, 4.34285714]),
        (1.0, [0.84210526, 0.84210526], 2.10526316,
         [14.73684211, 18.10526316, 4.63157895]),
    ]
    for mu, coef, intercept, pred in cases:
        model = RidgeRegression(mu=mu).fit(X, y)
        assert_array_almost_equal(coef, model.coef_)
        assert_almost_equal(intercept, model.intercept_)
        assert_array_almost_equal(pred, model.predict(T))
Comments
Table of contents


Questions


Interview experiences


Articles

WeChat
Official account

Scan the QR code to follow our official account