def svgd_gradient(X0):
    """Build the symbolic SVGD-style update direction for the batch ``X0``.

    The returned direction is the sum of two terms:

    * a driving term: the negative gradient of the summed ``mse`` output of
      ``discrim`` with respect to ``X0``;
    * a kernel-gradient term: the derivative of an RBF kernel, evaluated
      between each row of the hidden features and the row picked by the
      ``neighbors`` index from ``rbf_kernel``, mapped back to input space
      through ``discrim``. It enters the sum with weight 1/2.

    Args:
        X0: symbolic input batch passed to ``discrim`` and indexed along its
            first axis by ``neighbors``.

    Returns:
        Tuple ``(grad, svgd_grad, dxkxy)``: the driving term, the combined
        direction ``grad + dxkxy / 2``, and the kernel-gradient term.
    """
    hidden, _, mse = discrim(X0)

    # Driving term: descend the summed mse with respect to the inputs.
    grad = -1.0 * T.grad(mse.sum(), X0)

    # TODO: carried over from the original code without further detail.
    # The first return value (kxy) is only needed by the full kernel-weighted
    # variant, (kxy . flatten(grad) + dxkxy) / sum(kxy, axis=1), which the
    # original kept commented out; the neighbor-based form below is used.
    # NOTE(review): h is presumably the kernel bandwidth - confirm against
    # rbf_kernel.
    _kxy, neighbors, h = rbf_kernel(hidden)

    # RBF kernel value between each row of hidden and its neighbor's row:
    # exp(-||hidden[neighbors] - hidden||^2 / (2 * h^2)).
    # NOTE(review): assumes hidden is 2-D and neighbors is a 1-D index vector,
    # so that coff is 1-D as dimshuffle(0, 'x') requires - confirm.
    neighbor_hidden = hidden[neighbors]
    diff = hidden - neighbor_hidden
    coff = T.exp(-T.sum(diff ** 2, axis=1) / h**2 / 2.0)

    # Derivative of that kernel value with respect to the neighbor's hidden
    # features (h treated as a constant).
    v = coff.dimshuffle(0, 'x') * diff / h**2

    # Re-run discrim on the neighbor inputs so the hidden features are a
    # direct function of X1, then apply the L-operator (v times the Jacobian
    # of hidden1 with respect to X1) to move the kernel gradient into input
    # space.
    X1 = X0[neighbors]
    hidden1, _, _ = discrim(X1)
    dxkxy = T.Lop(hidden1, X1, v)

    svgd_grad = grad + dxkxy / 2.
    return grad, svgd_grad, dxkxy
# (Leftover web-page navigation text, not code: "Comment list" / "Article table of contents".)