def solve_unit_norm_dual(lhs, rhs, lambd0, factr=1e7, debug=False,
                         lhs_is_toeplitz=False):
    """Solve a block-norm-constrained quadratic problem through its dual.

    The vector ``x`` (length ``lhs.shape[0]``) is split into
    ``n_atoms = lambd0.shape[0]`` consecutive blocks ``x_k`` of
    ``n_times_atom = lhs.shape[0] // n_atoms`` entries each. With one
    multiplier ``lambd_k >= 0`` per block, this maximizes the dual

        g(lambd) = x^T lhs x - 2 rhs^T x + sum_k lambd_k (||x_k||^2 - 1)

    evaluated at ``x = (lhs + diag(lambd repeated per block))^{-1} rhs``.
    That is the Lagrangian dual of minimizing ``x^T lhs x - 2 rhs^T x``
    subject to ``||x_k||^2 <= 1`` for every block. The maximization is
    done with L-BFGS-B.

    NOTE(review): the closed form for ``x`` is only the stationary point of
    the Lagrangian when ``lhs`` is symmetric. Presumably ``lhs`` is also
    positive semi-definite; confirm with callers.

    Parameters
    ----------
    lhs : ndarray, shape (n, n)
        Quadratic term. ``n`` is presumably a multiple of ``n_atoms``;
        otherwise the per-block reshape fails.
    rhs : ndarray, shape (n,)
        Linear term.
    lambd0 : ndarray, shape (n_atoms,)
        Initial multipliers for L-BFGS-B. It is not modified.
    factr : float
        Stopping tolerance forwarded to ``scipy.optimize.fmin_l_bfgs_b``.
    debug : bool
        If True, assert (via ``scipy.optimize.check_grad``) that the
        analytic dual gradient matches finite differences at ``lambd0``.
    lhs_is_toeplitz : bool
        If True, ``lhs`` is treated as a symmetric Toeplitz matrix given by
        its first row, and the systems are solved with
        ``scipy.linalg.solve_toeplitz``. Requires ``n_atoms == 1``.

    Returns
    -------
    x_hat : ndarray, shape (n,)
        Primal solution at the optimal multipliers.
    lambd_hats : ndarray, shape (n_atoms,)
        Optimal multipliers. This is the float ``0.`` when ``rhs`` is all
        zeros.

    Raises
    ------
    ValueError
        If ``lhs_is_toeplitz`` is True and ``n_atoms != 1``.
    """
    if np.all(rhs == 0):
        # With a zero linear term, x = 0 is returned directly. Note that the
        # multiplier is returned as a scalar here, unlike the general path.
        return np.zeros(lhs.shape[0]), 0.

    n_atoms = lambd0.shape[0]
    n_times_atom = lhs.shape[0] // n_atoms

    if lhs_is_toeplitz:
        # lhs would not stay Toeplitz if a different lambd were added on the
        # diagonal of each block, so only a single block is supported.
        if n_atoms != 1:
            raise ValueError("lhs_is_toeplitz=True requires a single atom "
                             "(lambd0 of length 1), got %d." % n_atoms)
        # First row of lhs. For a symmetric Toeplitz matrix it equals the
        # first column that solve_toeplitz expects. Promote to (at least)
        # float64, as the dense branch does via ``lhs + np.diag(...)``, so
        # adding lambd is never truncated for integer input.
        lhs_c = lhs[0, :].astype(np.result_type(lhs.dtype, np.float64))

        def x_star(lambd):
            # Rebind instead of ``+=``: the array belongs to the optimizer or
            # the caller and must not be mutated.
            lambd = lambd + 1e-14  # avoid numerical issues
            lhs_c_shifted = lhs_c.copy()
            # Single block: lambd holds exactly one value.
            lhs_c_shifted[0] += np.ravel(lambd)[0]
            return linalg.solve_toeplitz(lhs_c_shifted, rhs)
    else:
        def x_star(lambd):
            # Rebind instead of ``+=``: the array belongs to the optimizer or
            # the caller and must not be mutated.
            lambd = lambd + 1e-14  # avoid numerical issues
            # Each block's multiplier is repeated over its n_times_atom
            # diagonal entries.
            return linalg.solve(lhs + np.diag(np.repeat(lambd, n_times_atom)),
                                rhs)

    def dual(lambd):
        x_hats = x_star(lambd)
        norms = linalg.norm(x_hats.reshape(-1, n_times_atom), axis=1)
        return (x_hats.T.dot(lhs).dot(x_hats) - 2 * rhs.T.dot(x_hats)
                + np.dot(lambd, norms ** 2 - 1.))

    def grad_dual(lambd):
        # d g / d lambd_k = ||x_k||^2 - 1 at the minimizing x.
        x_hats = x_star(lambd).reshape(-1, n_times_atom)
        return linalg.norm(x_hats, axis=1) ** 2 - 1.

    # L-BFGS-B minimizes, so the dual is maximized through its negation.
    def func(lambd):
        return -dual(lambd)

    def grad(lambd):
        return -grad_dual(lambd)

    # Multipliers of inequality constraints are non-negative.
    bounds = [(0., None)] * n_atoms
    if debug:
        assert optimize.check_grad(func, grad, lambd0) < 1e-5
    lambd_hats, _, _ = optimize.fmin_l_bfgs_b(func, x0=lambd0, fprime=grad,
                                              bounds=bounds, factr=factr)
    x_hat = x_star(lambd_hats)
    return x_hat, lambd_hats
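# NOTE(review): two web-page navigation labels were scraped along with the
# code here: "评论列表" ("comment list") and "文章目录" ("table of contents").
# They are kept only as this comment.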