def ksvd(Y, D, X, n_cycles=1, verbose=True):
    """Refine a dictionary and its sparse codes with K-SVD atom updates.

    Each cycle visits every atom ``k`` (column ``D[:, k]``) in turn.  The
    samples whose code uses atom ``k`` (non-zero entries of ``X[k, :]``) are
    selected, atom ``k``'s contribution is added back to their residual, and
    the rank-1 approximation of that restricted residual (computed with
    ``randomized_svd``) replaces the atom and its coefficients.  The sparsity
    pattern of ``X`` is left unchanged: only entries that are already
    non-zero are rewritten.

    Parameters
    ----------
    Y : ndarray, shape (n_features, n_samples)
        Data matrix, one sample per column.
    D : ndarray, shape (n_features, n_atoms)
        Dictionary.  Modified **in place**; should have a floating dtype,
        since values assigned into an integer array would be truncated.
    X : ndarray, shape (n_atoms, n_samples)
        Sparse codes.  Modified **in place**; should have a floating dtype.
    n_cycles : int, default 1
        Number of passes over all atoms.
    verbose : bool, default True
        If true, write a carriage-return progress line to stdout and finish
        it with a newline.  If false, nothing is written.

    Returns
    -------
    D : ndarray
        The same object that was passed in, after the updates.
    X : ndarray
        The same object that was passed in, after the updates.
    unused_atoms : list of int
        Indices (increasing, each at most once) of atoms that no sample
        used during the final cycle.  Empty when ``n_cycles`` is 0.

    Raises
    ------
    ValueError
        If ``Y`` is not 2-D or its shape is not ``(D.shape[0], X.shape[1])``.
    """
    n_atoms = D.shape[1]
    # Validate up front: otherwise a mis-shaped Y could silently broadcast
    # against D @ X and produce a meaningless residual.
    if Y.ndim != 2 or Y.shape != (D.shape[0], X.shape[1]):
        raise ValueError(
            "Y must have shape (n_features, n_samples) = %r, got %r"
            % ((D.shape[0], X.shape[1]), Y.shape)
        )

    # Residual of the current approximation.  It is patched incrementally
    # after every atom update, so it never has to be recomputed in full.
    R = Y - np.dot(D, X)
    unused_atoms = []
    for _ in range(n_cycles):
        # Reset per cycle so each unused atom is reported once, reflecting
        # the state of the final pass rather than accumulating duplicates.
        unused_atoms = []
        for k in range(n_atoms):
            if verbose:
                # (k + 1) so the indicator reaches 100% on the last atom.
                sys.stdout.write("\rk-svd...:%3.2f%%" % ((k + 1) * 100.0 / n_atoms))
                sys.stdout.flush()
            # Samples whose sparse code uses atom k.
            omega_k = X[k, :] != 0
            if not np.any(omega_k):
                unused_atoms.append(k)
                continue
            # Residual of those samples when atom k is removed from the model.
            Rk = R[:, omega_k] + np.outer(D[:, k], X[k, omega_k])
            # Best rank-1 fit of Rk; an explicit random_state keeps results
            # reproducible regardless of the library's default seed policy.
            U, S, V = randomized_svd(Rk, n_components=1, n_iter=10,
                                     flip_sign=False, random_state=0)
            D[:, k] = U[:, 0]
            X[k, omega_k] = V[0, :] * S[0]
            # Put the updated atom's contribution back into the residual.
            R[:, omega_k] = Rk - np.outer(D[:, k], X[k, omega_k])
    if verbose:
        sys.stdout.write("\n")
        sys.stdout.flush()
    return D, X, unused_atoms
评论列表
文章目录