ksvd.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: Lyssandra 作者: ektormak 项目源码 文件源码
def nn_ksvd(Y, D, X, n_cycles=1, verbose=True):
    """Run one non-negative K-SVD sweep over every dictionary atom.

    For each atom ``k``, the atom ``D[:, k]`` and the non-zero entries of its
    coefficient row ``X[k, :]`` are refit to the residual left by all other
    atoms. A rank-1 SVD of that residual gives the starting point. The start
    is projected onto the non-negative orthant and then refined by
    ``n_cycles`` rounds of alternating non-negative least squares. Only
    entries of ``X`` that are already non-zero are updated, so the support of
    ``X`` never grows.

    Parameters
    ----------
    Y : ndarray, shape (n_features, n_samples)
        Data matrix.
    D : ndarray, shape (n_features, n_atoms)
        Dictionary. Modified in place.
    X : ndarray, shape (n_atoms, n_samples)
        Codes. Modified in place.
    n_cycles : int, default 1
        Number of alternating projection rounds per atom.
    verbose : bool, default True
        If true, write a progress indicator (and a final newline) to stdout.

    Returns
    -------
    D : ndarray
        The updated dictionary (the same object that was passed in).
    X : ndarray
        The updated codes (the same object that was passed in).
    unused_atoms : list of int
        Indices of atoms with no non-zero coefficient in ``X``. These atoms
        are left unchanged.

    Notes
    -----
    An atom is also left unchanged when its SVD fails (a warning is issued)
    or when its non-negative rank-1 fit collapses to (numerically) zero.
    """
    eps = np.finfo(float).eps
    n_atoms = D.shape[1]
    unused_atoms = []
    # Residual of the current approximation. It is kept in sync as atoms change.
    # np.dot is what sklearn's (since removed) fast_dot wrapped for dense input.
    R = Y - np.dot(D, X)

    for k in range(n_atoms):
        if verbose:
            sys.stdout.write("\r" + "k-svd..." + ":%3.2f%%" % ((k / float(n_atoms)) * 100))
            sys.stdout.flush()
        # Samples whose code uses atom k.
        omega_k = X[k, :] != 0
        if not np.any(omega_k):
            unused_atoms.append(k)
            continue
        # Residual of those samples with atom k's own contribution added back.
        Rk = R[:, omega_k] + np.outer(D[:, k], X[k, omega_k])
        try:
            U, S, V = randomized_svd(Rk, n_components=1, n_iter=50, flip_sign=False)
        except (np.linalg.LinAlgError, ValueError):
            # Non-convergence or non-finite input. Skip this atom, but do not
            # hide unrelated errors or interrupts the way a bare except did.
            warnings.warn('SVD error')
            continue

        # With flip_sign=False the sign of the singular pair is arbitrary, and
        # (u, v) and (-u, -v) give the same rank-1 term. Keep the sign whose
        # non-negative part retains more energy, so the projection does not
        # discard the useful half of the approximation (ties keep +u, +v).
        u, v = U[:, 0], V[0, :]
        d, x = np.maximum(u, 0), np.maximum(v, 0)
        d_neg, x_neg = np.maximum(-u, 0), np.maximum(-v, 0)
        if np.dot(d_neg, d_neg) * np.dot(x_neg, x_neg) > np.dot(d, d) * np.dot(x, x):
            d, x = d_neg, x_neg
        x = x * S[0]

        dTd = np.dot(d, d)
        xTx = np.dot(x, x)
        if dTd <= eps or xTx <= eps:
            continue

        # Alternating non-negative least squares. Stop as soon as either factor
        # collapses to ~0, because the next division would yield NaN/inf.
        for _ in range(n_cycles):
            d = np.dot(Rk, x) / xTx
            d[d < 0] = 0
            dTd = np.dot(d, d)
            if dTd <= eps:
                break
            x = np.dot(d, Rk) / dTd
            x[x < 0] = 0
            xTx = np.dot(x, x)
            if xTx <= eps:
                break
        if dTd <= eps or xTx <= eps:
            # Degenerate fit: leave atom k, its codes and the residual unchanged.
            continue

        # Normalize the atom to unit length and move its scale into the codes.
        d_norm = np.sqrt(dTd)
        D[:, k] = d / d_norm
        X[k, omega_k] = x * d_norm
        # Update the residual for the affected samples.
        R[:, omega_k] = Rk - np.outer(D[:, k], X[k, omega_k])

    if verbose:
        # End the carriage-return progress line.
        sys.stdout.write("\n")
    return D, X, unused_atoms
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号