admm.py 文件源码

python
阅读 47 收藏 0 点赞 0 评论 0

项目:l1l2py 作者: slipguru 项目源码 文件源码
def enet_admm(X, y, z=None, rho=1.0, alpha=1.0, max_iter=1000, abs_tol=1e-6,
              rel_tol=1e-4, tau=0.5, mu=0.5):
    """Run ADMM iterations for an elastic-net style regression problem.

    The routine alternates an x-update (a linear solve through the
    triangular factors returned by ``factor``), a z-update (``shrinkage``
    applied to the over-relaxed iterate) and a scaled dual update, until
    both the primal and the dual residual fall below their tolerances or
    ``max_iter`` iterations have been performed.

    NOTE(review): ``factor`` and ``shrinkage`` are defined elsewhere in this
    module; ``shrinkage`` is presumably the soft-thresholding operator and
    ``mu`` presumably the l2 weight folded into the factorisation -- confirm
    against those helpers.

    Parameters
    ----------
    X : 2-D array, shape (n, d)
        Data matrix.
    y : array
        Target values; combined with ``X`` as ``np.dot(X.T, y)``.
    z : array of d elements, optional
        Starting point for the ``z`` variable (warm start).  When ``None``
        the iteration starts from zeros.
    rho : float
        ADMM penalty parameter.
    alpha : float
        Over-relaxation parameter (``alpha == 1`` means no relaxation).
    max_iter : int
        Maximum number of iterations.
    abs_tol, rel_tol : float
        Absolute and relative tolerances of the stopping criterion.
    tau : float
        Threshold weight; ``shrinkage`` is called with ``tau / rho``.
    mu : float
        Forwarded unchanged to ``factor``.

    Returns
    -------
    z : array of d elements
        Final ``z`` iterate.
    s_norm : float
        Norm of the last dual residual (``inf`` if no iteration ran).
    eps_dual : float
        Dual tolerance used in the last convergence test.
    n_iter : int
        Number of iterations actually performed.

    Raises
    ------
    ValueError
        If a supplied ``z`` cannot be reshaped to ``d`` elements.
    """
    n, d = X.shape

    # Loop-invariant pieces of the x-update, computed once.
    scaled_XTy = 2. / n * np.dot(X.T, y)
    fat_scale = 2. / (n * rho * rho)

    # Honour a caller-supplied starting point instead of discarding it.
    # ``z`` is only ever rebound below, never modified in place, so the
    # caller's array is left untouched.
    if z is None:
        z = np.zeros(d)
    else:
        z = np.asarray(z, dtype=float).reshape(d)
    u = np.zeros(d)

    L, U = factor(X, rho, mu)

    # Values reported when the loop body never executes (max_iter <= 0):
    # no dual residual has been measured, and with u == 0 the dual tolerance
    # reduces to its absolute part.
    n_iter = 0
    s_norm = np.inf
    eps_dual = np.sqrt(d) * abs_tol

    # ``range`` (rather than ``xrange``) keeps this runnable on Python 2 and 3.
    for k in range(max_iter):
        n_iter = k + 1

        # x-update
        q = scaled_XTy + rho * (z - u)    # right-hand side of the solve

        if n >= d:      # skinny X: solve the d x d system directly
            x = la.solve_triangular(U, la.solve_triangular(L, q, lower=True),
                                    lower=False)
        else:           # fat X: solve an n x n system, then map back to R^d
            tmp = la.solve_triangular(U, la.solve_triangular(L, np.dot(X, q),
                                      lower=True), lower=False)
            x = q / rho - np.dot(X.T, tmp) * fat_scale

        # z-update with relaxation
        zold = z
        x_hat = alpha * x + (1 - alpha) * zold
        z = shrinkage(x_hat + u, tau / rho)

        # u-update (scaled dual variable)
        u += (x_hat - z)

        # Stopping criterion: primal and dual residuals against tolerances
        r_norm = la.norm(x - z)
        s_norm = la.norm(-rho * (z - zold))

        eps_pri = np.sqrt(d) * abs_tol + rel_tol * max(la.norm(x), la.norm(-z))
        eps_dual = np.sqrt(d) * abs_tol + rel_tol * la.norm(rho * u)

        if (r_norm < eps_pri) and (s_norm < eps_dual):
            break

    return z, s_norm, eps_dual, n_iter
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号