test_double.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:l1l2py 作者: slipguru 项目源码 文件源码
def _test_double_optimization():
    """Check DoubleStepEstimator on a toy problem padded with noise features.

    The target is the plain row sum of two informative features. Both the
    training and the test matrices then get 100 extra uniformly random
    columns. A Lasso selection step (tau=1.0) followed by a RidgeRegression
    refit (mu=0.0) is expected to:

    * give equal Lasso weights on the two informative columns,
    * give unit Ridge weights on the selected columns, and
    * predict the exact row sums of the informative test features.
    """
    # Informative part of the design; the target is each row's sum.
    train_core = [[1, 2], [3, 4], [5, 6]]
    targets = [first + second for first, second in train_core]
    test_core = [[7, 8], [9, 10], [2, 1]]

    # Append irrelevant noise columns. The seed and the draw order
    # (training noise first, then test noise) pin the random values.
    np.random.seed(0)
    train = np.hstack((train_core, np.random.random((3, 100))))
    test = np.hstack((test_core, np.random.random((3, 100))))

    # Two-step fit: the Lasso selects variables, then the ridge model is
    # fitted on the selection (its beta has one entry per kept variable).
    model = DoubleStepEstimator(Lasso(tau=1.0), RidgeRegression(mu=0.0))
    model = model.train(train, targets)

    # Coefficients of each stage and of the combined estimator.
    selector, estimator = model.selector, model.estimator
    assert_array_almost_equal([0.90635646, 0.90635646], selector.beta[:2])
    assert_array_almost_equal([1.0, 1.0], estimator.beta)
    assert_array_almost_equal([1.0, 1.0], model.beta[:2])

    # Each prediction equals the sum of the row's two informative features.
    predictions = model.predict(test)
    assert_array_almost_equal([15., 19., 3.], predictions)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号