def _update(self, eta=None):
    """Return an updated copy of ``self.A`` after ``self.inner_epo`` gradient steps.

    The codes ``Z`` are estimated once from the current ``self.A`` as
    ``pinv(A) * Y`` with every entry below ``self.alpha`` set to 0. They are
    then held fixed while ``A_t`` takes ``self.inner_epo`` steps of

        A_t <- A_t + eta * (Y Z^T - A_t Z Z^T)

    i.e. gradient descent on 0.5 * ||Y - A_t Z||_F^2 with respect to ``A_t``.

    ``*`` is used as matrix multiplication on ``self.A`` / ``self.Y``, which
    presumes they are ``numpy.matrix`` instances (TODO confirm against the
    caller).

    Args:
        eta: Step size. Defaults to ``0.5 / D`` with ``D = self.A.shape[1]``
            (the value previously hard-coded and marked "to tune").

    Returns:
        The updated matrix ``A_t``. ``self.A`` itself is not modified.

    Side effects:
        Calls ``self.show_error()`` and flushes stdout when ``self.A_true``
        has any non-zero entry. Adds the elapsed wall-clock time of the
        update (excluding that error report) to ``self.time``.
    """
    if eta is None:
        D = self.A.shape[1]
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        eta = 0.5 / float(D)  # to tune
    A_t = self.A.copy()
    if self.A_true.any():
        self.show_error()
        sys.stdout.flush()
    start = time.time()
    A_inv = pinv(self.A)
    # Hard-threshold the codes: entries below alpha become 0. This inlines the
    # behavior of the previously called threshold(..., threshmin=alpha)
    # (presumably scipy.stats.threshold, removed in SciPy 1.0): np.array()
    # copies the product into a plain ndarray, as that function did.
    Z = np.array(A_inv * self.Y)
    Z[Z < self.alpha] = 0
    Zt = Z.transpose()
    # Z is fixed inside the loop, so both products are loop-invariant.
    # Z is a plain ndarray, so use np.dot (ndarray `*` would be elementwise).
    YZt = self.Y * Zt
    ZZt = np.dot(Z, Zt)
    for _ in range(self.inner_epo):
        A_t = A_t + eta * (YZt - A_t * ZZt)
    self.time += time.time() - start
    return A_t
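# Stray navigation labels left over from the web page this code was copied from:
# "评论列表" ("Comment list") and "文章目录" ("Table of contents").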