tests.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:hoag 作者: OuYag 项目源码 文件源码
def test_LogisticRegressionCV():
    """Check that LogisticRegressionCV selects an alpha near a grid-search optimum.

    The 20 newsgroups training set is turned into a binary problem and split
    in half into a training part (Xt, yt) and a held-out part (Xh, yh). A
    reference alpha is found by fitting a plain logistic regression for each
    value on a small grid (with C = exp(-alpha)) and keeping the one with the
    lowest logistic loss on the held-out part. The test passes when the
    ``alpha_`` attribute of the fitted LogisticRegressionCV lies within 0.5
    of that reference value.
    """
    bunch = fetch_20newsgroups_vectorized(subset="train")
    X = bunch.data
    y = bunch.target

    # Binarize the labels around their mean. The threshold and the mask are
    # computed once, before any relabeling: calling y.mean() again after the
    # first in-place assignment would compare against a shifted mean.
    threshold = y.mean()
    below = y < threshold
    y[below] = -1
    y[~below] = 1
    Xt, Xh, yt, yh = cross_validation.train_test_split(
        X, y, test_size=.5, random_state=0)

    # compute the held-out loss for every alpha on the grid
    all_scores = []
    all_alphas = np.linspace(-12, 0, 5)
    for a in all_alphas:
        lr = linear_model.LogisticRegression(
            solver='lbfgs', C=np.exp(-a), fit_intercept=False, tol=1e-6,
            max_iter=100)
        lr.fit(Xt, yt)
        # NOTE(review): the trailing 0 is presumably the regularization
        # strength of the private helper, i.e. the held-out loss is
        # unpenalized -- confirm against the installed scikit-learn version.
        score_scv = linear_model.logistic._logistic_loss(
            lr.coef_.ravel(), Xh, yh, 0)
        all_scores.append(score_scv)
    all_scores = np.array(all_scores)

    # Reference value: the grid point with the smallest held-out loss.
    best_alpha = all_alphas[np.argmin(all_scores)]

    clf = LogisticRegressionCV(verbose=True)
    clf.fit(Xt, yt, Xh, yh)
    assert np.abs(clf.alpha_ - best_alpha) < 0.5
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号