def dfdv(W, G, y, hparams):
    """Gradient of the L2-regularised multinomial logistic (softmax) loss.

    Column k of the gradient is

        df/dw_k = -1/N * sum_{t: y_t == k} x_t
                  + 1/N * sum_t softmax_k(W' x_t) * x_t
                  + 2 * l * w_k

    where x_t is column t of ``G``.

    Parameters
    ----------
    W : ndarray
        Weights with ``d * K`` entries in any shape that reshapes to
        ``(d, K)`` (e.g. the flat vector a numerical optimiser passes in).
    G : ndarray, shape (d, N)
        Feature matrix, one sample per column.
    y : array_like
        The N class labels. Any shape holding N entries is accepted:
        1-D, (N, 1), (1, N) or a plain list. Labels outside ``0..K-1``
        contribute nothing to the first (label) term.
    hparams : dict
        ``'K'``: number of classes; ``'l'``: L2 regularisation weight.

    Returns
    -------
    ndarray
        The gradient, with the same shape as the incoming ``W``.

    Raises
    ------
    AssertionError
        If the gradient contains NaN (only while assertions are enabled).
    """
    K = hparams['K']
    lam = hparams['l']
    d, N = G.shape
    orig_shape = W.shape
    W = W.reshape((d, K))
    # Flatten so (N,), (N, 1), (1, N) and list labels all index samples.
    labels = np.asarray(y).reshape(-1)

    # Class scores are shifted by the per-sample maximum so that exp() cannot
    # overflow. The shift cancels in the softmax ratio.
    scores = np.dot(W.T, G)                                 # K x N
    scores = scores - scores.max(axis=0, keepdims=True)
    exp_scores = np.exp(scores)
    probs = exp_scores / exp_scores.sum(axis=0, keepdims=True)   # K x N

    # One-hot indicator: onehot[k, t] = 1 if y_t == k, else 0.
    onehot = (np.arange(K)[:, None] == labels[None, :]).astype(probs.dtype)

    # Both data terms of the formula in one product, plus the L2 term.
    df = np.dot(G, (probs - onehot).T) / N + 2. * lam * W   # d x K
    assert not np.isnan(df).any()
    return df.reshape(orig_shape)
# (Leftover web-page navigation text: "Comment list", "Table of contents".)