def crossEntrGrad(y, trueY, G):
    """Solve the KKT sensitivity system for a binary cross-entropy objective.

    ``y`` is clipped into ``[eps, 1 - eps]`` and the function returns the
    solution of

        [[diag(z), G.T,  0],     [cy  ]   [dl]
         [G,       0,   -1],  @  [clam] = [0 ]
         [0,      -1,    0]]     [ct  ]   [0 ]

    where ``z = 1/y + 1/(1-y)`` is the second derivative of
    ``y*log(y) + (1-y)*log(1-y)``. ``dl = trueY/y - (1-trueY)/(1-y)`` is the
    derivative of the log-likelihood ``trueY*log(y) + (1-trueY)*log(1-y)``.

    Instead of forming the (n+k+1)-square system, ``cy`` is eliminated
    (Schur complement). Only the (k+1)-square system for ``clam`` and ``ct``
    is solved, and ``cy`` is recovered as ``(dl - G.T @ clam) / z``.

    NOTE(review): this looks like the backward pass of an optimization layer
    whose output is ``y`` under equality constraints built from ``G``.
    ``dl`` is the *negative* of the cross-entropy gradient. Confirm the sign
    convention callers expect.

    Args:
        y: length-n vector (values expected in [0, 1]; array or list).
            Components exactly equal to 0 or 1 get ``cy == 0``.
        trueY: targets, combined elementwise with ``y`` (assumed length n;
            array or list).
        G: 2-D array of shape (k, n).

    Returns:
        Tuple ``(cy, clam, ct)`` of 1-D arrays with shapes (n,), (k,), (1,).

    Raises:
        AssertionError: if ``len(y) != G.shape[1]``.
    """
    k, n = G.shape
    assert len(y) == n, "len(y) must equal G.shape[1]"
    # Convert up front: with a plain list, ``y == 0`` below would be a single
    # bool and the boundary mask would silently select nothing.
    y_arr = np.asarray(y)
    trueY = np.asarray(trueY)

    # Keep y strictly inside (0, 1) so the divisions below stay finite.
    eps = 1e-8
    y_ = np.clip(y_arr, eps, 1. - eps)  # np.clip returns a new array

    z = 1. / y_ + 1. / (1. - y_)
    zinv = 1. / z
    dl = trueY / y_ - (1 - trueY) / (1 - y_)

    # G * zinv scales column j of G by zinv[j], i.e. G @ diag(zinv), without
    # materialising the n x n diagonal matrix.
    G_zinv = G * zinv
    G_zinv_GT = np.dot(G_zinv, G.T)

    # Reduced system after eliminating cy:
    #   [[G Z^-1 G^T, 1], [1^T, 0]] @ [clam; ct] = [G Z^-1 dl; 0]
    H = np.block([[G_zinv_GT, np.ones((k, 1))],
                  [np.ones((1, k)), np.zeros((1, 1))]])
    b = np.concatenate([np.dot(G_zinv, dl), np.zeros(1)])
    clamt = np.linalg.solve(H, b)
    clam, ct = np.split(clamt, [k])

    # Back-substitute: cy = Z^-1 (dl - G^T clam).
    cy = zinv * dl - np.dot(G_zinv.T, clam)
    # Zero the components where the original y lies exactly on 0 or 1.
    cy[(y_arr == 0) | (y_arr == 1)] = 0
    return cy, clam, ct
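# (Residue scraped from the source web page, kept as a comment:
#  "评论列表" = "comment list", "文章目录" = "table of contents".)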