FISTA.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:fastmat 作者: EMS-TU-Ilmenau 项目源码 文件源码
def FISTA(
        fmatA,
        arrB,
        numLambda=0.1,
        numMaxSteps=100
):
    '''
    Wrapper around the FISTA algorithm to allow processing of arrays of signals

    Runs the Fast Iterative Shrinkage-Thresholding Algorithm: each step is a
    gradient step on the data-fidelity term followed by soft thresholding,
    and the next evaluation point is extrapolated from the last two iterates.

        fmatA         - input system matrix; must provide ``forward``,
                        ``backward``, ``largestSV`` and ``numM``
        arrB          - input data (measurements): either a single vector of
                        shape (n,) or an array of shape (n, m) whose m
                        columns are processed as separate measurement vectors
        numLambda     - balancing parameter in optimization problem
                        between data fidelity and sparsity
        numMaxSteps   - maximum number of steps to run; if no step runs
                        (numMaxSteps <= 0) an all-zero estimate is returned

    Returns the estimate with shape (fmatA.numM, m), or (fmatA.numM,) when
    arrB is one-dimensional. Wherever the final thresholded iterate is
    non-zero, the unthresholded value of the last gradient step is returned;
    all other entries are zero.

    Raises ValueError if arrB is not one- or two-dimensional.
    '''

    if arrB.ndim not in (1, 2):
        raise ValueError("Only n x m arrays are supported for FISTA")

    # Treat a single measurement vector as an n x 1 array, and undo this on
    # return so that 1-D input yields 1-D output.
    isVector = (arrB.ndim == 1)
    if isVector:
        arrB = arrB.reshape((-1, 1))

    # Step size: 1 / largestSV**2, the reciprocal of the squared largest
    # singular value of fmatA.
    numL = 1.0 / (fmatA.largestSV ** 2)
    # Threshold level is constant over all iterations.
    numThreshold = numL * numLambda * 0.5
    t = 1
    arrX = np.zeros(
        (fmatA.numM, arrB.shape[1]),
        dtype=np.promote_types(np.float32, arrB.dtype)
    )
    # initial extrapolation point (kept as a separate array from arrX)
    arrY = np.copy(arrX)
    # Pre-bind so the return below is well-defined even if no step runs.
    arrStep = arrX

    for _ in range(numMaxSteps):
        # arrX is only ever rebound (never modified in place), so keeping a
        # reference to the previous iterate is sufficient.
        arrXold = arrX
        # gradient step on the data-fidelity term, then thresholding
        arrStep = arrY - numL * fmatA.backward(fmatA.forward(arrY) - arrB)
        arrX = _softThreshold(arrStep, numThreshold)

        # update the momentum parameter t
        tOld = t
        t = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        # extrapolate the next evaluation point from the last two iterates
        arrY = arrX + ((tOld - 1) / t) * (arrX - arrXold)

    # return the unthresholded values for all non-zero support elements
    arrResult = np.where(arrX != 0, arrStep, arrX)
    return arrResult[:, 0] if isVector else arrResult


################################################################################
###  Maintenance and Documentation
################################################################################

################################################## inspection interface
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号