learningAlg.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:MinimaxFilter 作者: jihunhamm 项目源码 文件源码
def f(W,G,y,hparams):
        # f = -1/N*sum_t log(exp(w(yt)'gt)/sum_k exp(wk'gt)) + l*||W||
        # = -1/N*sum_t [w(yt)'*gt - log(sum_k exp(wk'gt))] + l*||W||
        # = -1/N*sum(sum(W(:,y).*G,1),2) + 1/N*sum(log(sumexpWG),2) + l*sum(sum(W.^2));

        #K,l = hparams
        K = hparams['K']
        l = hparams['l']
        d,N = G.shape
        W = W.reshape((d,K))

        WG = np.dot(W.T,G) # K x N
        WG -= np.kron(np.ones((K,1)),WG.max(axis=0).reshape(1,N))
        #WG_max = WG.max(axis=0).reshape((1,N))
        #expWG = np.exp(WG-np.kron(np.ones((K,1)),WG_max)) # K x N
        expWG = np.exp(WG) # K x N
        sumexpWG = expWG.sum(axis=0) # N x 1
        WyG = WG[y,range(N)]
        #WyG -= WG_max

        fval = -1.0/N*(WyG).sum() \
            + 1.0/N*np.log(sumexpWG).sum() \
            + l*(W**2).sum()#(axis=(0,1))

        return fval
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号