steingan_lsun.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:SteinGAN 作者: DartML 项目源码 文件源码
def svgd_gradient(X0):

    hidden, _, mse = discrim(X0)
    grad = -1.0 * T.grad( mse.sum(), X0)

    kxy, neighbors, h = rbf_kernel(hidden)  #TODO

    coff = T.exp( - T.sum((hidden[neighbors] - hidden)**2, axis=1) / h**2 / 2.0 )
    v = coff.dimshuffle(0, 'x') * (-hidden[neighbors] + hidden) / h**2

    X1 = X0[neighbors]
    hidden1, _, _ = discrim(X1)
    dxkxy = T.Lop(hidden1, X1, v)

    #svgd_grad = (T.dot(kxy, T.flatten(grad, 2)).reshape(dxkxy.shape) + dxkxy) / T.sum(kxy, axis=1).dimshuffle(0, 'x', 'x', 'x')
    svgd_grad = grad + dxkxy / 2.
    return grad, svgd_grad, dxkxy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号