def FISTA(
    fmatA,
    arrB,
    numLambda=0.1,
    numMaxSteps=100
):
    '''
    Wrapper around the FISTA algorithm to allow processing of arrays of signals

    fmatA       - input system matrix; must provide ``forward``, ``backward``,
                  ``largestSV`` and ``numM`` (``numM`` is used as the leading
                  dimension of the returned array)
    arrB        - input data (measurements): either a single vector of shape
                  (n,) or an (n x k) array whose k columns are processed
                  together
    numLambda   - balancing parameter in optimization problem
                  between data fidelity and sparsity
    numMaxSteps - maximum number of iterations to run; a value <= 0 performs
                  no iteration and returns all zeros

    Returns an array of shape (fmatA.numM, k), or (fmatA.numM,) when arrB is
    1-D. Where the final thresholded iterate is non-zero, the entries hold the
    unthresholded gradient-step values of the last iteration; all other
    entries are zero.

    Raises ValueError if arrB is neither 1-D nor 2-D.
    '''
    numDims = len(arrB.shape)
    if numDims not in (1, 2):
        raise ValueError("Only n x m arrays are supported for FISTA")

    # A single measurement vector is processed as an (n x 1) array and
    # flattened again on return (it previously failed on arrB.shape[1]).
    isVector = numDims == 1
    if isVector:
        arrB = arrB.reshape(-1, 1)

    # calculate the largest singular value to get the right step size;
    # presumably 1 / L with L the Lipschitz constant of the data-fidelity
    # gradient (assumes fmatA.backward is the adjoint of fmatA.forward)
    numL = 1.0 / (fmatA.largestSV ** 2)
    # loop-invariant soft-threshold level
    numThreshold = numL * numLambda * 0.5
    t = 1

    arrX = np.zeros(
        (fmatA.numM, arrB.shape[1]),
        dtype=np.promote_types(np.float32, arrB.dtype)
    )
    # initial arrY
    arrY = np.copy(arrX)
    # bound up front so numMaxSteps <= 0 returns zeros instead of raising
    # UnboundLocalError in the final np.where
    arrStep = arrX

    for _ in range(numMaxSteps):
        # arrX is only rebound below (never modified in place), so keeping a
        # reference to the previous iterate is enough; no copy required
        arrXold = arrX
        # gradient step on the data-fidelity term, then thresholding via the
        # module's _softThreshold helper (defined elsewhere in this file)
        arrStep = arrY - numL * fmatA.backward(fmatA.forward(arrY) - arrB)
        arrX = _softThreshold(arrStep, numThreshold)
        # update the FISTA momentum parameter t
        tOld = t
        t = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        # extrapolate from the last two iterates
        arrY = arrX + ((tOld - 1) / t) * (arrX - arrXold)

    # return the unthresholded values for all non-zero support elements
    arrResult = np.where(arrX != 0, arrStep, arrX)
    return arrResult[:, 0] if isVector else arrResult
################################################################################
### Maintenance and Documentation
################################################################################
################################################## inspection interface
评论列表
文章目录