learningAlg.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:MinimaxFilter 作者: jihunhamm 项目源码 文件源码
def dfdv(W,G,y,hparams):
        # df/dwk = -1/N*sum(x(:,y==k),2) + 1/N*sum_t exp(wk'xt)*xt/(sum_k exp(wk'xt))] + l*2*wk
        K = hparams['K']
        l = hparams['l']
        d,N = G.shape
        shapeW = W.shape
        W = W.reshape((d,K))

        WG = np.dot(W.T,G) # K x N
        WG -= np.kron(np.ones((K,1)),WG.max(axis=0).reshape(1,N))
        expWG = np.exp(WG) # K x N
        sumexpWG = expWG.sum(axis=0) # N x 1
        df = np.zeros((d,K))
        for k in range(K):
            indk = np.where(y==k)[0]    
            df[:,k] = -1./N*G[:,indk].sum(axis=1).reshape((d,)) \
                + 1./N*np.dot(G,(expWG[k,:]/sumexpWG).T).reshape((d,)) \
                + 2.*l*W[:,k].reshape((d,))

        assert np.isnan(df).any()==False        

        return df.reshape(shapeW)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号