icnn_ebundle.py source code

python

Project: icnn, author: locuslab
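```python
import numpy as np


def crossEntrGrad(y, trueY, G):
    # G has shape (k, n); y holds the n predictions and trueY the targets.
    # Solves a linear KKT system whose right-hand side is the derivative of
    # the log-likelihood, and returns the solution split into (cy, clam, ct).
    k, n = G.shape
    assert len(y) == n
    eps = 1e-8
    # Clip away from 0 and 1 so the 1/y and 1/(1-y) terms stay finite.
    y_ = np.clip(y, eps, 1. - eps)

    # Full (n+k+1)-dimensional KKT system, kept for reference:
    # H = np.bmat([[np.diag(1./y_ + 1./(1.-y_)), G.T, np.zeros((n,1))],
    #              [G, np.zeros((k,k)), -np.ones((k,1))],
    #              [np.zeros((1,n)), -np.ones((1,k)), np.zeros((1,1))]])
    # c = -np.linalg.solve(H, np.concatenate([trueY/y_-(1-trueY)/(1-y_), np.zeros(k+1)]))
    # b = np.concatenate([trueY/y_-(1-trueY)/(1-y_), np.zeros(k+1)])
    # cy, clam, ct = np.split(c, [n, n+k])
    # cy[(y == 0) | (y == 1)] = 0

    # Reduced system: the top-left block diag(z) is diagonal, so cy can be
    # eliminated (Schur complement), leaving a (k+1) x (k+1) solve.
    z = 1. / y_ + 1. / (1. - y_)   # Hessian diagonal of y*log(y) + (1-y)*log(1-y)
    zinv = 1. / z
    G_zinv = G * zinv              # G @ diag(zinv) via broadcasting
    G_zinv_GT = np.dot(G_zinv, G.T)
    H = np.block([[G_zinv_GT, np.ones((k, 1))],
                  [np.ones((1, k)), np.zeros((1, 1))]])
    dl = trueY / y_ - (1 - trueY) / (1 - y_)   # d/dy of the Bernoulli log-likelihood
    b = np.concatenate([np.dot(G_zinv, dl), np.zeros(1)])
    clamt = np.linalg.solve(H, b)
    clam, ct = np.split(clamt, [k])
    # Back-substitute to recover cy from the first block row.
    cy = zinv * dl - np.dot(G_zinv.T, clam)
    # Entries sitting exactly on the box bounds get a zero gradient.
    cy[(y == 0) | (y == 1)] = 0
    return cy, clam, ct
```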
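As a sanity check on the reduction, the sketch below solves both the full (n+k+1)-dimensional KKT system from the commented-out lines and the reduced (k+1)-dimensional system used by the active code, then compares the results. Eliminating cy from the first block row gives cy = Z^-1 (dl - G^T lam) with Z = diag(z); substituting into the second row (G cy - t 1 = 0) and keeping the third row's constraint 1^T lam = 0 yields exactly the reduced matrix [[G Z^-1 G^T, 1], [1^T, 0]] with right-hand side [G Z^-1 dl, 0]. The commented-out code negates the solve (`c = -np.linalg.solve(...)`), while the active code corresponds to the solve without the minus sign, so the full solver here drops it. The helper names `cross_entr_grad_reduced` and `cross_entr_grad_full` are hypothetical, and the check assumes G has full row rank and every y lies strictly inside (0, 1), so no entries are masked.

```python
import numpy as np


def cross_entr_grad_reduced(y, true_y, G):
    """Hypothetical helper: reduced (Schur-complement) solve, mirroring crossEntrGrad."""
    k, n = G.shape
    assert len(y) == n
    eps = 1e-8
    y_ = np.clip(y, eps, 1.0 - eps)
    z = 1.0 / y_ + 1.0 / (1.0 - y_)      # diagonal of Z
    zinv = 1.0 / z
    G_zinv = G * zinv                     # G @ Z^-1 via broadcasting
    H = np.block([[G_zinv @ G.T, np.ones((k, 1))],
                  [np.ones((1, k)), np.zeros((1, 1))]])
    dl = true_y / y_ - (1 - true_y) / (1 - y_)
    b = np.concatenate([G_zinv @ dl, np.zeros(1)])
    clam, ct = np.split(np.linalg.solve(H, b), [k])
    cy = zinv * dl - G_zinv.T @ clam      # back-substitution for cy
    cy[(y == 0) | (y == 1)] = 0
    return cy, clam, ct


def cross_entr_grad_full(y, true_y, G):
    """Hypothetical helper: solve the full KKT system from the commented-out code."""
    k, n = G.shape
    eps = 1e-8
    y_ = np.clip(y, eps, 1.0 - eps)
    z = 1.0 / y_ + 1.0 / (1.0 - y_)
    H = np.block([[np.diag(z), G.T, np.zeros((n, 1))],
                  [G, np.zeros((k, k)), -np.ones((k, 1))],
                  [np.zeros((1, n)), -np.ones((1, k)), np.zeros((1, 1))]])
    dl = true_y / y_ - (1 - true_y) / (1 - y_)
    b = np.concatenate([dl, np.zeros(k + 1)])
    c = np.linalg.solve(H, b)             # no leading minus, unlike the comment
    cy, clam, ct = np.split(c, [n, n + k])
    cy[(y == 0) | (y == 1)] = 0
    return cy, clam, ct


# Random test problem: G is almost surely full row rank, y stays inside (0, 1).
rng = np.random.default_rng(0)
k, n = 3, 5
G = rng.standard_normal((k, n))
y = rng.uniform(0.1, 0.9, size=n)
true_y = (rng.uniform(size=n) > 0.5).astype(float)

reduced = cross_entr_grad_reduced(y, true_y, G)
full = cross_entr_grad_full(y, true_y, G)
for name, r, f in zip(("cy", "clam", "ct"), reduced, full):
    print(name, np.allclose(r, f))
```

The reduced form is the cheaper choice when k is much smaller than n: it factors a (k+1) x (k+1) matrix instead of an (n+k+1) x (n+k+1) one, and the diagonal block is inverted elementwise. If G is rank-deficient, G Z^-1 G^T loses definiteness and the reduced matrix may become singular.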