Cnnrescal.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:sictf 作者: malllabiisc 项目源码 文件源码
def _initR(X, A, lmbdaR):
    _log.debug('Initializing R (SVD) lambda R: %s' % str(lmbdaR))
    rank = A.shape[1]
    U, S, Vt = svd(A, full_matrices=False)
    Shat = kron(S, S)
    Shat = (Shat / (Shat ** 2 + lmbdaR)).reshape(rank, rank)
    R = []
    ep = 1e-9
    for i in range(len(X)): # parallelize
        Rn = Shat * dot(U.T, X[i].dot(U))
        Rn = dot(Vt.T, dot(Rn, Vt))

        negativeVal = Rn.min()
        Rn.__iadd__(-negativeVal+ep)
        # if Rn.min() < 0 :
        #   print("Negative Rn!")
        #   raw_input("Press Enter: ")
        # Rn = np.eye(rank)
        # Rn = dot(A.T,A)

        R.append(Rn)
    print('Initialized R')
    return R

# ------------------ Update V ------------------------------------------------
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号