def nn_ksvd(Y, D, X, n_cycles=1, verbose=True):
    """Run one sweep of non-negative K-SVD atom updates.

    For every atom ``k``, the residual that atom ``k`` alone should explain is
    formed from the samples that currently use it. A rank-1 SVD of that
    residual gives a starting (atom, coefficients) pair. The pair is projected
    onto the non-negative orthant and refined by ``n_cycles`` rounds of
    alternating non-negative least squares.

    Parameters
    ----------
    Y : ndarray, shape (n_features, n_samples)
        Data matrix.
    D : ndarray, shape (n_features, n_atoms)
        Dictionary. Updated in place (presumably a float array -- integer
        arrays would truncate the updates; TODO confirm with callers).
    X : ndarray, shape (n_atoms, n_samples)
        Coefficient matrix. Updated in place; only entries that are already
        non-zero are changed, so its sparsity pattern is preserved.
    n_cycles : int, default 1
        Number of alternating refinement rounds after the SVD initialisation.
    verbose : bool, default True
        If True, write a progress indicator (and a final newline) to stdout.

    Returns
    -------
    D, X : ndarray
        The same arrays that were passed in, after the in-place updates.
    unused_atoms : list of int
        Indices of atoms whose row in ``X`` was entirely zero. These atoms
        were left untouched.

    Notes
    -----
    Atoms are also left unchanged when the SVD fails (a warning is issued) or
    when the non-negative projection of the SVD solution is (near) zero.
    The residual is updated after each atom, so later atoms in the sweep see
    the already-updated earlier ones.
    """
    eps = np.finfo(float).eps
    n_atoms = D.shape[1]
    unused_atoms = []
    # Residual of the current factorisation, maintained incrementally below.
    R = Y - np.dot(D, X)
    for k in range(n_atoms):
        if verbose:
            sys.stdout.write("\r" + "k-svd..." + ":%3.2f%%" % ((k / float(n_atoms)) * 100))
            sys.stdout.flush()
        # Samples whose code uses atom k; only these coefficients may change.
        omega_k = X[k, :] != 0
        if not np.any(omega_k):
            unused_atoms.append(k)
            continue
        # Residual on omega_k with atom k's own contribution added back, i.e.
        # the part of the data that atom k alone should explain.
        Rk = R[:, omega_k] + np.outer(D[:, k], X[k, omega_k])
        try:
            U, S, V = randomized_svd(Rk, n_components=1, n_iter=50, flip_sign=False)
        except Exception as exc:
            # Best-effort: skip this atom, but say why. Unlike a bare
            # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
            warnings.warn("SVD error (atom %d): %s" % (k, exc))
            continue
        d = U[:, 0]
        x = V[0, :] * S[0]
        # Project the rank-1 SVD solution onto the non-negative orthant.
        # NOTE(review): with flip_sign=False the sign of (d, x) is arbitrary;
        # if both come out mostly negative, the projection empties them and
        # the atom is skipped below.
        d[d < 0] = 0
        x[x < 0] = 0
        if np.dot(d, d) <= eps or np.dot(x, x) <= eps:
            continue
        # Alternating non-negative least squares on the rank-1 problem.
        # Each guard stops before a division by a (near) zero norm. This
        # keeps the last non-degenerate (d, x) pair instead of producing NaNs.
        for _ in range(n_cycles):
            xTx = np.dot(x, x)
            if xTx <= eps:
                break
            d_new = np.dot(Rk, x) / xTx
            d_new[d_new < 0] = 0
            dTd = np.dot(d_new, d_new)
            if dTd <= eps:
                break
            d = d_new
            x = np.dot(d, Rk) / dTd
            x[x < 0] = 0
        # Normalise the atom to unit length and move its scale into x, so the
        # product d x^T is unchanged. d is never (near) zero here: it either
        # passed the projection check or was accepted by the loop guard.
        d_norm = np.linalg.norm(d)
        D[:, k] = d / d_norm
        X[k, omega_k] = x * d_norm
        # Fold the updated atom back into the global residual.
        R[:, omega_k] = Rk - np.outer(D[:, k], X[k, omega_k])
    if verbose:
        # Terminate the "\r" progress line; nothing is printed when quiet.
        sys.stdout.write("\n")
        sys.stdout.flush()
    return D, X, unused_atoms
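# 评论列表 ("comment list") -- stray label scraped from the source web page
# 文章目录 ("table of contents") -- stray label scraped from the source web page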