def enet_admm(X, y, z=None, rho=1.0, alpha=1.0, max_iter=1000, abs_tol=1e-6,
              rel_tol=1e-4, tau=0.5, mu=0.5):
    """Run over-relaxed ADMM for an elastic-net style regression problem.

    The variable is split into ``x`` and ``z``:

    * ``x`` is updated by a linear solve using the triangular factors
      returned by ``factor(X, rho, mu)``.
    * ``z`` is updated with ``shrinkage(x_hat + u, tau / rho)``, where
      ``shrinkage`` is presumably soft-thresholding (confirm).
    * ``u`` is the scaled dual variable.

    The right-hand side ``(2/n) X^T y + rho (z - u)`` suggests a
    ``(1/n) ||X x - y||^2`` data term (TODO confirm against ``factor``).

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Design matrix.
    y : ndarray
        Targets. Presumably shape (n,); a 2-D ``y`` would make the updates
        broadcast incorrectly.
    z : array_like of shape (d,), optional
        Initial value (warm start) for ``z``. Defaults to zeros.
    rho : float
        ADMM penalty parameter.
    alpha : float
        Over-relaxation parameter. ``alpha == 1`` means no relaxation
        (``x_hat == x``).
    max_iter : int
        Maximum number of iterations. Must be >= 1.
    abs_tol, rel_tol : float
        Absolute and relative tolerances for the primal/dual stopping test.
    tau : float
        ``shrinkage`` is applied with threshold ``tau / rho``.
    mu : float
        Forwarded to ``factor``; not used directly in this function.

    Returns
    -------
    z : ndarray, shape (d,)
        Final value of the ``z`` iterate.
    s_norm : float
        Dual residual norm ``||rho (z - z_old)||`` at the last iteration.
    eps_dual : float
        Dual tolerance at the last iteration.
    n_iter : int
        Number of iterations performed.

    Raises
    ------
    ValueError
        If ``max_iter < 1`` or ``z`` does not have shape ``(d,)``.
    """
    if max_iter < 1:
        # At least one iteration is required for s_norm, eps_dual and the
        # iteration count to be defined.
        raise ValueError("max_iter must be >= 1, got %r" % (max_iter,))

    n, d = X.shape
    if z is None:
        z = np.zeros(d)
    else:
        # Warm start: copy so the caller's array is never aliased.
        z = np.array(z, dtype=float)
        if z.shape != (d,):
            raise ValueError("z must have shape (%d,), got %r" % (d, z.shape))
    u = np.zeros(d)

    # Loop-invariant quantities, computed once.
    XTy_scaled = 2. / n * np.dot(X.T, y)
    eps_abs = np.sqrt(d) * abs_tol

    L, U = factor(X, rho, mu)
    for k in range(max_iter):
        # x-update
        q = XTy_scaled + rho * (z - u)
        if n >= d:  # skinny: solve the d x d system directly via L U
            x = la.solve_triangular(U, la.solve_triangular(L, q, lower=True),
                                    lower=False)
        else:  # fat: work with the n x n system instead
            # NOTE(review): this looks like the matrix inversion lemma and
            # assumes factor() returned factors of I + (2/(n*rho)) X X^T.
            # mu does not appear in this formula; confirm factor() is
            # consistent with the skinny branch.
            tmp = la.solve_triangular(
                U, la.solve_triangular(L, np.dot(X, q), lower=True),
                lower=False)
            x = q / rho - np.dot(X.T, tmp) * (2. / (n * rho * rho))

        # z-update with over-relaxation
        zold = z
        x_hat = alpha * x + (1 - alpha) * zold
        z = shrinkage(x_hat + u, tau / rho)

        # u-update (scaled dual); u is a private array, so in-place is safe
        u += (x_hat - z)

        # Stopping criteria on primal (r) and dual (s) residuals
        r_norm = la.norm(x - z)
        s_norm = la.norm(-rho * (z - zold))
        eps_pri = eps_abs + rel_tol * max(la.norm(x), la.norm(z))
        eps_dual = eps_abs + rel_tol * la.norm(rho * u)
        if (r_norm < eps_pri) and (s_norm < eps_dual):
            break
    return z, s_norm, eps_dual, k + 1
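# 评论列表 ("comment list") / 文章目录 ("table of contents"): navigation
# residue from the web page this code was copied from, kept as comments so
# the module remains importable.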