test_sag.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_get_auto_step_size():
    """Check ``get_auto_step_size`` against hand-computed step sizes.

    For both the "squared" and "log" losses, with and without an intercept,
    the value returned by ``get_auto_step_size`` must match the closed-form
    expressions computed below (to 4 decimals). An unknown loss name must
    raise ``ValueError`` with a descriptive message.
    """
    X = np.array([[1, 2, 3], [2, 3, 4], [2, 3, 2]], dtype=np.float64)
    alpha = 1.2
    # Sum the squares of the second sample because that's the largest
    # row norm: 2^2 + 3^2 + 4^2 = 29 (vs. 14 and 17 for the other rows).
    max_squared_sum = 4 + 9 + 16
    max_squared_sum_ = row_norms(X, squared=True).max()
    assert_almost_equal(max_squared_sum, max_squared_sum_, decimal=4)

    for fit_intercept in (True, False):
        # Expected values: fitting an intercept adds int(fit_intercept)
        # (i.e. 1 or 0) to the max squared sum in both denominators.
        step_size_sqr = 1.0 / (max_squared_sum + alpha + int(fit_intercept))
        step_size_log = 4.0 / (max_squared_sum + 4.0 * alpha +
                               int(fit_intercept))

        step_size_sqr_ = get_auto_step_size(max_squared_sum_, alpha, "squared",
                                            fit_intercept)
        step_size_log_ = get_auto_step_size(max_squared_sum_, alpha, "log",
                                            fit_intercept)

        assert_almost_equal(step_size_sqr, step_size_sqr_, decimal=4)
        assert_almost_equal(step_size_log, step_size_log_, decimal=4)

    # An unsupported loss name must be rejected. fit_intercept is passed
    # explicitly rather than relying on the loop variable leaking out.
    msg = 'Unknown loss function for SAG solver, got wrong instead of'
    assert_raise_message(ValueError, msg, get_auto_step_size,
                         max_squared_sum_, alpha, "wrong", False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号