def f(W, G, y, hparams):
    """Return the L2-regularised softmax (multinomial logistic) loss.

    With g_t the t-th column of G and w_k the k-th column of W::

        f = -1/N * sum_t log( exp(w_{y_t}' g_t) / sum_k exp(w_k' g_t) )
            + l * sum(W**2)
          = -1/N * sum_t [ w_{y_t}' g_t - log(sum_k exp(w_k' g_t)) ]
            + l * sum(W**2)

    Note that the penalty is the *squared* norm of W (sum of squared
    entries), not the norm itself.

    Args:
        W: array holding d*K weights; it is reshaped to (d, K), so a flat
            parameter vector is accepted.
        G: array of shape (d, N), one column per sample.
        y: N integer class labels, used to index the K rows of the score
            matrix. Any shape holding N entries (flat, (1, N) or (N, 1))
            is accepted; it is flattened before indexing.
        hparams: mapping with keys 'K' (number of classes) and 'l'
            (regularisation weight).

    Returns:
        The scalar loss value.
    """
    K = hparams['K']
    reg = hparams['l']
    d, N = G.shape
    weights = W.reshape((d, K))

    # Flatten the labels: an (N, 1) column would otherwise broadcast against
    # the sample index below and select an N x N block of scores.
    labels = np.asarray(y).reshape(-1)

    scores = np.dot(weights.T, G)  # K x N
    # Subtract each sample's maximum score. The loss is invariant to this
    # per-column shift, and it keeps exp() from overflowing.
    scores = scores - scores.max(axis=0, keepdims=True)
    log_norm = np.log(np.exp(scores).sum(axis=0))  # one value per sample

    # Shifted score of the true class for every sample.
    true_scores = scores[labels, np.arange(N)]

    fval = (-1.0 / N * true_scores.sum()
            + 1.0 / N * log_norm.sum()
            + reg * (weights ** 2).sum())
    return fval
# [web-page navigation residue] Comment list
# [web-page navigation residue] Table of contents