test_pyglmnet.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pyglmnet 作者: glm-tools 项目源码 文件源码
def test_cv():
    """Simple CV check."""
    # XXX: don't use scikit-learn for tests.
    X, y = make_regression()
    cv = KFold(X.shape[0], 5)

    glm_normal = GLM(distr='gaussian', alpha=0.01, reg_lambda=0.1)
    # check that it returns 5 scores
    scores = cross_val_score(glm_normal, X, y, cv=cv)
    assert_equal(len(scores), 5)

    param_grid = [{'alpha': np.linspace(0.01, 0.99, 2)},
                  {'reg_lambda': np.logspace(np.log(0.5), np.log(0.01),
                                             10, base=np.exp(1))}]
    glmcv = GridSearchCV(glm_normal, param_grid, cv=cv)
    glmcv.fit(X, y)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号