test_logistic.py source code

python

Project: Parallel-SGD    Author: angadgill
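from numpy.testing import assert_array_almost_equal
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def test_logistic_regression_solvers_multiclass():
    # Small 3-class problem. All four solvers minimize the same convex,
    # L2-regularized objective, so their coefficients should agree to 4 decimals.
    X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
                               n_classes=3, random_state=0)
    tol = 1e-6
    ncg = LogisticRegression(solver='newton-cg', fit_intercept=False, tol=tol)
    lbf = LogisticRegression(solver='lbfgs', fit_intercept=False, tol=tol)
    # No solver given: this test was written for scikit-learn < 0.22, where the
    # default solver was 'liblinear' and multi_class defaulted to 'ovr'.
    lib = LogisticRegression(fit_intercept=False, tol=tol)
    # SAG is stochastic: allow more epochs and fix the seed for reproducibility.
    sag = LogisticRegression(solver='sag', fit_intercept=False, tol=tol,
                             max_iter=1000, random_state=42)
    ncg.fit(X, y)
    lbf.fit(X, y)
    sag.fit(X, y)
    lib.fit(X, y)
    # Compare every pair of solvers.
    assert_array_almost_equal(ncg.coef_, lib.coef_, decimal=4)
    assert_array_almost_equal(lib.coef_, lbf.coef_, decimal=4)
    assert_array_almost_equal(ncg.coef_, lbf.coef_, decimal=4)
    assert_array_almost_equal(sag.coef_, lib.coef_, decimal=4)
    assert_array_almost_equal(sag.coef_, ncg.coef_, decimal=4)
    assert_array_almost_equal(sag.coef_, lbf.coef_, decimal=4)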
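The assertions above rely on `assert_array_almost_equal(..., decimal=4)`. NumPy documents its criterion as `abs(desired - actual) < 1.5 * 10**(-decimal)`, checked elementwise once the shapes are confirmed to match. The plain-Python sketch below uses a hypothetical helper, `almost_equal`, that mirrors that rule so it is easy to see how far apart two solvers' coefficients may be and still pass. It is an illustration only, not NumPy's implementation, and it ignores NumPy's special handling of NaN and infinity.

```python
def almost_equal(actual, desired, decimal=4):
    """Hypothetical helper mirroring the documented NumPy criterion
    abs(desired - actual) < 1.5 * 10**(-decimal), applied elementwise.
    Accepts scalars or (nested) lists, e.g. a coef_ matrix given as rows."""
    actual_is_seq = isinstance(actual, (list, tuple))
    desired_is_seq = isinstance(desired, (list, tuple))
    if actual_is_seq or desired_is_seq:
        # A sequence compared with a scalar, or sequences of different
        # lengths, count as a shape mismatch.
        if not (actual_is_seq and desired_is_seq) or len(actual) != len(desired):
            return False
        return all(almost_equal(a, d, decimal) for a, d in zip(actual, desired))
    return abs(desired - actual) < 1.5 * 10 ** (-decimal)


# Illustrative coefficient rows from two solvers (made-up numbers).
coef_a = [[0.12345, -0.5], [1.0, 2.0]]
coef_b = [[0.12350, -0.5], [1.0, 2.0]]
print(almost_equal(coef_a, coef_b))             # True: 5e-5 < 1.5e-4
print(almost_equal(coef_a, coef_b, decimal=5))  # False: 5e-5 >= 1.5e-5
```

With `decimal=4`, coefficients that differ by up to about 1.5e-4 still pass, which leaves room for solvers that stop at slightly different points near the same optimum.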