def update_X(X, mu, k=6):
    """Shrink the top-``k`` singular values of ``X`` and rebuild the matrix.

    Each singular value ``s`` of the rank-``k`` truncated SVD of ``X`` is
    replaced by the minimiser over ``sigma >= 0`` of

        f(sigma) = log(1 + sigma) + mu * (sigma - s)**2

    Setting ``f'(sigma) = 0`` and multiplying by ``(1 + sigma) / (2 * mu)``
    gives the monic quadratic

        sigma**2 + (1 - s) * sigma + (1 / (2 * mu) - s) = 0.

    The minimiser is chosen as follows:

    * discriminant <= 0: ``f`` is non-decreasing on ``sigma >= 0``, so the
      result is 0.
    * larger root <= 0: no admissible stationary point, so the result is 0.
    * constant term <= 0: the larger root is the only non-negative one and
      ``f'(0) <= 0``, so that root is the minimiser.
    * both roots positive: the larger root is a local minimum. It is kept
      only if ``f(root) <= f(0) = mu * s**2``; otherwise the result is 0.

    Parameters
    ----------
    X : array_like or sparse matrix, shape (m, n)
        Input matrix, passed to ``scipy.sparse.linalg.svds``, which requires
        ``0 < k < min(m, n)``.
    mu : float
        Weight of the quadratic term. It must be positive because it is
        inverted when forming the quadratic.
    k : int, default 6
        Number of leading singular triplets kept.

    Returns
    -------
    scipy.sparse.lil_matrix, shape (m, n)
        ``U @ diag(sigma_star) @ VT``, with entries of magnitude below 1e-10
        set to zero.
    """
    U, S, VT = svds(X, k=k, which='LM')

    # Coefficients of sigma**2 + b * sigma + c = 0, one per singular value.
    b = 1. - S
    c = 1. / 2. / mu - S
    delta = b ** 2 - 4. * c

    # Larger real root. Clamping delta avoids NaN from sqrt; cases with
    # delta <= 0 are masked out below anyway.
    root = (-b + np.sqrt(np.maximum(delta, 0.))) / 2.
    admissible = (delta > 0) & (root > 0)

    # Compare the objective at the root with its value at sigma = 0
    # (mu * s**2). Inadmissible roots are replaced by 0 so that log1p never
    # sees values <= -1.
    safe_root = np.where(admissible, root, 0.)
    f_root = np.log1p(safe_root) + mu * (safe_root - S) ** 2
    # When c <= 0 the admissible root is the global minimiser on sigma >= 0.
    beats_zero = (c <= 0) | (f_root <= mu * S ** 2)

    sigma_star = np.where(admissible & beats_zero, root, 0.)

    # U * sigma_star scales column j of U by sigma_star[j],
    # which is the same as U @ diag(sigma_star).
    X_new = (U * sigma_star) @ VT
    # Snap numerical noise to exact zeros before the sparse conversion.
    X_new[np.abs(X_new) < 1e-10] = 0
    return sp.lil_matrix(X_new)
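# NOTE(review): scraped web-page residue, commented out so the module imports
# cleanly ("评论列表" = "comment list", "文章目录" = "table of contents").
# 评论列表
# 文章目录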