update_d.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:alphacsc 作者: alphacsc 项目源码 文件源码
def solve_unit_norm_dual(lhs, rhs, lambd0, factr=1e7, debug=False,
                         lhs_is_toeplitz=False):
    """Minimize a quadratic under per-atom norm constraints via its dual.

    The primal variable ``x`` is split into ``n_atoms`` consecutive chunks of
    ``n_times_atom`` entries. For multipliers ``lambd >= 0`` (one per chunk)
    the inner solution is

        x_star(lambd) = solve(lhs + diag(repeat(lambd, n_times_atom)), rhs)

    and the dual value is

        x.T @ lhs @ x - 2 * rhs.T @ x + sum_k lambd[k] * (||x_k|| ** 2 - 1).

    The dual is maximized over ``lambd >= 0`` with L-BFGS-B.

    Parameters
    ----------
    lhs : ndarray, shape (n_atoms * n_times_atom, n_atoms * n_times_atom)
        Matrix of the quadratic term.
    rhs : ndarray, shape (n_atoms * n_times_atom,)
        Vector of the linear term.
    lambd0 : ndarray, shape (n_atoms,)
        Initial value of the multipliers. It is not modified.
    factr : float
        Forwarded to ``scipy.optimize.fmin_l_bfgs_b``.
    debug : bool
        If True, compare the analytic gradient with a finite-difference
        estimate (``scipy.optimize.check_grad``) and assert that they agree.
    lhs_is_toeplitz : bool
        If True, solve the inner system with ``linalg.solve_toeplitz`` using
        only the first row of ``lhs``. This requires ``n_atoms == 1``, and
        since only one vector is passed to the solver, ``lhs`` has to be
        symmetric for the result to be meaningful.

    Returns
    -------
    x_hat : ndarray, shape (n_atoms * n_times_atom,)
        Primal solution evaluated at the returned multipliers.
    lambd_hats : ndarray, shape (n_atoms,)
        Multipliers found by the optimizer. When ``rhs`` is entirely zero,
        the function returns early with a zero vector and the scalar ``0.``.
    """
    if np.all(rhs == 0):
        return np.zeros(lhs.shape[0]), 0.

    n_atoms = lambd0.shape[0]
    n_times_atom = lhs.shape[0] // n_atoms

    def regularized(lambd):
        # A tiny positive offset avoids numerical issues when lambd == 0.
        # Build a new array rather than using ``lambd += ...``: the in-place
        # form would modify the optimizer's iterate, the caller's lambd0 and
        # the multipliers returned at the end.
        return np.asarray(lambd, dtype=float) + 1e-14

    if lhs_is_toeplitz:
        # First row of lhs; handed to solve_toeplitz as the first column,
        # which describes the same matrix only when lhs is symmetric.
        lhs_c = lhs[0, :]

        # lhs will not stay toeplitz if we add different lambd on the diagonal
        assert n_atoms == 1

        def x_star(lambd):
            shifted_c = lhs_c.copy()
            # Adding lambd to the whole diagonal of a Toeplitz matrix only
            # changes the first entry of its defining vector. Extract the
            # single multiplier as a scalar instead of relying on implicit
            # array-to-scalar conversion.
            shifted_c[0] += regularized(lambd).ravel()[0]
            return linalg.solve_toeplitz(shifted_c, rhs)

    else:

        def x_star(lambd):
            diagonal = np.repeat(regularized(lambd), n_times_atom)
            return linalg.solve(lhs + np.diag(diagonal), rhs)

    def dual(lambd):
        x_hats = x_star(lambd)
        norms = linalg.norm(x_hats.reshape(-1, n_times_atom), axis=1)
        return (x_hats.T.dot(lhs).dot(x_hats) - 2 * rhs.T.dot(x_hats) + np.dot(
            lambd, norms ** 2 - 1.))

    def grad_dual(lambd):
        x_hats = x_star(lambd).reshape(-1, n_times_atom)
        return linalg.norm(x_hats, axis=1) ** 2 - 1.

    # fmin_l_bfgs_b minimizes, so hand it the negated dual and gradient.
    def func(lambd):
        return -dual(lambd)

    def grad(lambd):
        return -grad_dual(lambd)

    # Every multiplier is constrained to be non-negative.
    bounds = [(0., None)] * n_atoms
    if debug:
        assert optimize.check_grad(func, grad, lambd0) < 1e-5
    lambd_hats, _, _ = optimize.fmin_l_bfgs_b(func, x0=lambd0, fprime=grad,
                                              bounds=bounds, factr=factr)
    x_hat = x_star(lambd_hats)
    return x_hat, lambd_hats
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号