def test_logistic_regression_solvers_multiclass():
    """Check that all LogisticRegression solvers agree on a 3-class problem.

    Four estimators (newton-cg, lbfgs, liblinear, sag) are fit on the same
    small synthetic dataset with no intercept, the same tolerance and the same
    multiclass strategy. Their ``coef_`` arrays must then match to 4 decimal
    places for every pair of solvers.
    """
    from itertools import combinations

    X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
                               n_classes=3, random_state=0)
    tol = 1e-6
    # Pin the multiclass strategy instead of relying on library defaults.
    # The coefficients are only comparable if every solver optimises the same
    # (one-vs-rest) objective, and liblinear handles multiclass problems as
    # one-vs-rest. NOTE(review): 'ovr' matches the default of the scikit-learn
    # version this test appears to target; confirm against the installed
    # version's `multi_class` parameter.
    params = dict(fit_intercept=False, tol=tol, multi_class='ovr')
    estimators = [
        ('newton-cg', LogisticRegression(solver='newton-cg', **params)),
        ('lbfgs', LogisticRegression(solver='lbfgs', **params)),
        # Name the liblinear solver explicitly rather than depending on it
        # being the default solver.
        ('liblinear', LogisticRegression(solver='liblinear', **params)),
        # sag is given a larger iteration budget and a fixed seed so its
        # result is reproducible.
        ('sag', LogisticRegression(solver='sag', max_iter=1000,
                                   random_state=42, **params)),
    ]

    for _, clf in estimators:
        clf.fit(X, y)

    # Compare every pair of solvers (6 pairs for 4 solvers), reporting which
    # pair disagrees on failure.
    for (name_a, clf_a), (name_b, clf_b) in combinations(estimators, 2):
        assert_array_almost_equal(clf_a.coef_, clf_b.coef_, decimal=4,
                                  err_msg="%s vs %s" % (name_a, name_b))
评论列表
文章目录