# Assumed module-level imports (aliases inferred from how they are used below;
# the original file defines these at the top of the module):
import numpy as np
import scipy.sparse as ssp
from scipy import linalg                      # linalg.inv; numpy.linalg would also work
from scipy.sparse import linalg as sp_linalg  # sp_linalg.cg

def update_W(m_opts, m_vars):
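    """Update the regression weights W (n_components x n_features) so that
    X_batch.dot(W.T) approximates U_batch under a ridge penalty lam_w.

    The solver is selected via m_opts:
      * use_cg == False: closed form -- invert the running matrix sigma_W
        (a moving average of X^T X + lam_w*I) and multiply by x_W (a moving
        average of X^T U).
      * use_cg == True, use_grad == False: conjugate gradient on the same
        running system, one latent component at a time.
      * use_cg == True, use_grad == True: gradient descent on the current
        batch's ridge loss, with norm-clipped updates and step-size halving
        when the gradient norm plateaus.
    """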
# print "Updating W"
    if not m_opts['use_grad']:
        # Accumulate moving averages of the ridge normal equations:
        # sigma_W tracks X^T X + lam_w*I and x_W tracks X^T U, mixed with weight gamma.
        sigma = m_vars['X_batch_T'].dot(m_vars['X_batch']) + m_opts['lam_w']*ssp.eye(m_vars['n_features'], format="csr")
        m_vars['sigma_W'] = (1-m_vars['gamma'])*m_vars['sigma_W'] + m_vars['gamma']*sigma
        x = m_vars['X_batch'].T.dot(m_vars['U_batch'])
        m_vars['x_W'] = (1-m_vars['gamma'])*m_vars['x_W'] + m_vars['gamma']*x
    if not m_opts['use_cg']: # Closed-form ridge solution for W: W^T = sigma_W^{-1} x_W
if ssp.issparse(m_vars['sigma_W']):
m_vars['sigma_W'] = m_vars['sigma_W'].todense()
sigma = linalg.inv(m_vars['sigma_W']) # O(N^3) time for N x N matrix inversion
m_vars['W'] = np.asarray(sigma.dot(m_vars['x_W'])).T
    else: # Iterative solvers: CG on the running ridge system, or gradient descent on the batch ridge loss
if not m_opts['use_grad']:
# assert m_vars['X_batch'].shape[0] == m_vars['U_batch'].shape[0]
            # Solve sigma_W * w_i = x_W[:, i] for each latent component with
            # conjugate gradient, warm-started from the current row of W.
            X = m_vars['sigma_W']
            for i in range(m_opts['n_components']):
                y = m_vars['x_W'][:, i]
                w, info = sp_linalg.cg(X, y, x0=m_vars['W'][i, :], maxiter=m_opts['cg_iters'])
                if info < 0:
                    print("WARNING: sp_linalg.cg info: illegal input or breakdown")
                m_vars['W'][i, :] = w.T
else:
            ''' Gradient descent on the batch ridge loss: solve X_batch.dot(W.T) ~= U_batch '''
            # print "Using grad!"
            my_invert = lambda x: x if x < 1 else 1.0/x  # fold a positive ratio into [0, 1]
            l2_norm = lambda x: np.sqrt((x**2).sum())
def clip_by_norm(x, clip_max):
x_norm = l2_norm(x)
if x_norm > clip_max:
# print "Clipped!",clip_max
x = clip_max*(x/x_norm)
return x
            # Polynomially decaying step sizes: lr_t = grad_alpha * (1 + t)^(-0.9)
            lr = m_opts['grad_alpha']*(1.0 + np.arange(m_opts['cg_iters']*10))**(-0.9) # (-0.75) also an option
try:
W_old = m_vars['W'].copy()
tail_norm, curr_norm = 1.0,1.0
for iter_idx in range(m_opts['cg_iters']*10):
                    # Gradient of 0.5*||X_batch W^T - U_batch||^2 + 0.5*lam_w*||W||^2, scaled by the step size
                    grad = m_vars['X_batch_T'].dot(m_vars['X_batch'].dot(m_vars['W'].T) - m_vars['U_batch'])
                    grad = lr[iter_idx]*(grad.T + m_opts['lam_w']*m_vars['W'])
                    # Exponential moving average of the scaled gradient norm, used to detect plateaus
                    tail_norm = 0.5*curr_norm + (1-0.5)*tail_norm
                    curr_norm = l2_norm(grad)
                    if curr_norm < 1e-15:
                        return # converged
                    elif iter_idx > 10 and my_invert(np.abs(tail_norm/curr_norm)) > 0.8:
                        # Gradient norm has stopped shrinking: halve the remaining step sizes
                        # print "Halved!"
                        lr = lr/2.0
                    m_vars['W'] = m_vars['W'] - clip_by_norm(grad, 1e0) # Clip the update by norm
Delta_W = l2_norm(m_vars['W']-W_old)
            except FloatingPointError:
                # Raised only when numpy error handling is set to 'raise' (np.seterr)
                print("FloatingPointError in:")
                print(grad)
                assert False
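
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The m_opts / m_vars keys below are
# the ones update_W reads; the shapes, initial values (identity sigma_W, zero
# x_W and W) and hyperparameter settings are assumptions for this example, not
# part of the original training code.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    n_samples, n_features, n_components = 64, 100, 10

    m_opts = {
        'use_grad': False,        # use the running ridge system, not gradient descent
        'use_cg': True,           # solve it with conjugate gradient
        'lam_w': 1e-2,            # ridge regularizer on W
        'n_components': n_components,
        'cg_iters': 5,
        'grad_alpha': 1e-2,       # base step size (only used when use_grad is True)
    }

    X_batch = ssp.csr_matrix(rng.rand(n_samples, n_features))
    m_vars = {
        'n_features': n_features,
        'gamma': 0.5,                                   # moving-average weight
        'X_batch': X_batch,
        'X_batch_T': X_batch.T.tocsr(),
        'U_batch': rng.rand(n_samples, n_components),
        'W': np.zeros((n_components, n_features)),
        'sigma_W': ssp.eye(n_features, format="csr"),   # assumed initialization
        'x_W': np.zeros((n_features, n_components)),    # assumed initialization
    }

    update_W(m_opts, m_vars)
    print(m_vars['W'].shape)  # (n_components, n_features)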