test_pyglmnet.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pyglmnet 作者: glm-tools 项目源码 文件源码
def test_group_lasso():
    """Group Lasso test."""
    n_samples, n_features = 100, 90

    # assign group ids
    groups = np.zeros(90)
    groups[0:29] = 1
    groups[30:59] = 2
    groups[60:] = 3

    # sample random coefficients
    beta0 = np.random.normal(0.0, 1.0, 1)
    beta = np.random.normal(0.0, 1.0, n_features)
    beta[groups == 2] = 0.

    # create an instance of the GLM class
    glm_group = GLM(distr='softplus', alpha=1.)

    # simulate training data
    Xr = np.random.normal(0.0, 1.0, [n_samples, n_features])
    yr = simulate_glm(glm_group.distr, beta0, beta, Xr)

    # scale and fit
    scaler = StandardScaler().fit(Xr)
    glm_group.fit(scaler.transform(Xr), yr)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号