def _omp(x, D, Gram, alpha, n_nonzero_coefs=None, tol=None):
_, n_atoms = D.shape
# the dict indexes of the atoms this datapoint uses
Dx = np.array([]).astype(int)
z = np.zeros(n_atoms)
# the residual
r = np.copy(x)
i = 0
if n_nonzero_coefs is not None:
tol = 1e-10
def cont_criterion():
not_reached_sparsity = i < n_nonzero_coefs
return (not_reached_sparsity and norm(r) > tol)
else:
cont_criterion = lambda: norm(r) >= tol
while (cont_criterion()):
# find the atom that correlates the
# most with the residual
k = np.argmax(np.abs(alpha))
if k in Dx:
break
Dx = np.append(Dx, k)
# solve the Least Squares problem
# to find the coefs z
DI = D[:, Dx]
G = Gram[Dx, :][:, Dx]
G = np.atleast_2d(G)
try:
G_inv = inv(G)
except LinAlgError:
print gram_singular_msg
break
z[Dx] = np.dot(G_inv, np.dot(D.T, x)[Dx])
r = x - np.dot(D[:, Dx], z[Dx])
alpha = np.dot(D.T, r)
i += 1
return z
评论列表
文章目录