model.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:GenEML 作者: nirbhayjm 项目源码 文件源码
def update_W(m_opts, m_vars):
    """Update the weight matrix ``m_vars['W']`` from the current mini-batch.

    The solver is selected by two option flags:

    * ``use_cg`` false: closed-form ridge solution, obtained by solving
      ``sigma_W * W' = x_W``.
    * ``use_cg`` true, ``use_grad`` false: one conjugate-gradient solve per
      component (row of ``W``), warm-started from the current row.
    * ``use_cg`` true, ``use_grad`` true: ``cg_iters * 10`` norm-clipped
      gradient steps on the ridge loss for ``X * W' = U``.

    Whenever ``use_grad`` is false, the running statistics ``sigma_W`` and
    ``x_W`` are first blended with this batch's statistics using the weight
    ``gamma``.

    Args:
        m_opts: option dict. Reads ``use_grad``, ``use_cg``, ``lam_w``,
            ``n_components``, ``cg_iters`` and ``grad_alpha``.
        m_vars: state dict. Reads ``X_batch``, ``X_batch_T``, ``U_batch``,
            ``n_features`` and ``gamma``; reads and writes ``sigma_W``,
            ``x_W`` and ``W``.

    Returns:
        None. All results are written back into ``m_vars``.

    Raises:
        AssertionError: if a ``FloatingPointError`` is raised during the
            gradient steps (the last gradient is printed first).
    """
    if not m_opts['use_grad']:
        gamma = m_vars['gamma']

        # Batch estimate of X'X + lam_w*I, blended into the running estimate.
        batch_sigma = (m_vars['X_batch_T'].dot(m_vars['X_batch'])
                       + m_opts['lam_w']*ssp.eye(m_vars['n_features'], format="csr"))
        m_vars['sigma_W'] = (1-gamma)*m_vars['sigma_W'] + gamma*batch_sigma

        # Batch estimate of X'U, blended the same way.
        batch_x = m_vars['X_batch'].T.dot(m_vars['U_batch'])
        m_vars['x_W'] = (1-gamma)*m_vars['x_W'] + gamma*batch_x

    if not m_opts['use_cg']:
        # Closed-form ridge regression on W.
        # NOTE(review): when use_grad is true the statistics above are not
        # refreshed, so this path reuses whatever sigma_W / x_W already hold.
        if ssp.issparse(m_vars['sigma_W']):
            # The densified matrix is stored back, as before.
            m_vars['sigma_W'] = m_vars['sigma_W'].todense()
        # Solve the linear system directly instead of forming the explicit
        # inverse and multiplying: same result, cheaper and better conditioned.
        m_vars['W'] = np.asarray(linalg.solve(m_vars['sigma_W'], m_vars['x_W'])).T
        return

    if not m_opts['use_grad']:
        # Conjugate gradient on the ridge normal equations, one row of W at a time.
        sigma_W = m_vars['sigma_W']
        for i in range(m_opts['n_components']):
            w, info = sp_linalg.cg(sigma_W, m_vars['x_W'][:, i],
                                   x0=m_vars['W'][i, :], maxiter=m_opts['cg_iters'])
            if info < 0:
                print("WARNING: sp_linalg.cg info: illegal input or breakdown")
            m_vars['W'][i, :] = w
        return

    # Gradient descent on the ridge loss, solving X*W' = U.
    n_steps = m_opts['cg_iters']*10
    max_step_norm = 1e0   # every update is clipped to this L2 norm
    # Polynomially decaying step sizes, one per iteration.
    lr = m_opts['grad_alpha']*(1.0 + np.arange(n_steps))**(-0.9)

    # Bound up front so the error handler can always print it, even when the
    # very first gradient evaluation is what fails.
    grad = None
    try:
        tail_norm, curr_norm = 1.0, 1.0
        for iter_idx in range(n_steps):
            grad = m_vars['X_batch_T'].dot(m_vars['X_batch'].dot(m_vars['W'].T) - m_vars['U_batch'])
            grad = lr[iter_idx]*(grad.T + m_opts['lam_w']*m_vars['W'])

            # Smoothed history of step norms versus the norm of this step.
            tail_norm = 0.5*curr_norm + 0.5*tail_norm
            curr_norm = np.sqrt((grad**2).sum())
            if curr_norm < 1e-15:
                # Step is numerically zero: nothing left to update.
                return

            if iter_idx > 10:
                # Fold the norm ratio into (0, 1]; a value above 0.8 means the
                # step size has stopped shrinking, so halve all learning rates.
                ratio = np.abs(tail_norm/curr_norm)
                closeness = ratio if ratio < 1 else 1.0/ratio
                if closeness > 0.8:
                    lr = lr/2.0

            # Clip by norm, reusing the norm computed above.
            if curr_norm > max_step_norm:
                step = max_step_norm*(grad/curr_norm)
            else:
                step = grad
            m_vars['W'] = m_vars['W'] - step
    except FloatingPointError:
        print("FloatingPointError in:")
        print(grad)
        # Raised explicitly: a bare `assert False` is removed under `python -O`,
        # which would let the failure pass silently.
        raise AssertionError("FloatingPointError in update_W gradient step")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号