def _kernel_matching(q1_x, q1_mu, xt_x, xt_mu, radius, res=10):
    """
    Kernel-fidelity cost between two weighted point clouds, plus a diagnostic map.

    The two measures q1 and xt are each given as a locations array and a
    weights array. The cost is

        0.5 * ( <q1_mu, K_qq q1_mu> + <xt_mu, K_xx xt_mu> - 2 <q1_mu, K_qx xt_mu> )

    where the three kernel matrices come from ``_cross_kernels``.

    Args:
        q1_x: locations of the first measure.
        q1_mu: weights of the first measure (1D, one entry per location).
        xt_x: locations of the second measure.
        xt_mu: weights of the second measure (1D, one entry per location).
        radius: kernel scale, forwarded unchanged to ``_cross_kernels`` and ``_k``.
        res: number of grid cells per axis for the diagnostic map. Defaults
            to 10; increase it to get smoother pictures.

    Returns:
        A two-element list ``[cost, info]`` where ``info`` is a ``(res, res)``
        tensor holding the kernel-blurred difference ``k * q1 - k * xt``
        sampled at the cell centres of a regular grid over the unit square.

    Raises:
        ValueError: if ``res`` is smaller than 1.
    """
    if res < 1:
        raise ValueError("res must be a positive integer, got {!r}".format(res))

    K_qq, K_qx, K_xx = _cross_kernels(q1_x, xt_x, radius)

    # Each term is a double sum  sum_ij K[i, j] * a[i] * b[j], written as an
    # elementwise product with the outer product of the weight vectors.
    # NOTE(review): torch.ger is a deprecated alias of torch.outer in recent
    # PyTorch; kept because the file targets the older Variable-based API.
    cost = .5 * (
        torch.sum(K_qq * torch.ger(q1_mu, q1_mu))
        + torch.sum(K_xx * torch.ger(xt_mu, xt_mu))
        - 2 * torch.sum(K_qx * torch.ger(q1_mu, xt_mu))
    )

    # Info = the 2D graph of the blurred distance function.
    # Sample at cell centres: drop the right edge of the linspace and shift
    # by half a cell, i.e. ticks = (i + 0.5) / res for i in 0..res-1.
    ticks = np.linspace(0, 1, res + 1)[:-1] + 1 / (2 * res)
    X, Y = np.meshgrid(ticks, ticks)

    # Flatten the grid into a (res*res, 2) list of sample points and convert
    # it to the module-wide tensor type `dtype` (defined elsewhere in the file).
    # NOTE(review): the grid is 2D, so the locations are presumably 2D points
    # in [0, 1]^2 -- confirm against the callers.
    grid = np.vstack((X.ravel(), Y.ravel())).T
    points = Variable(torch.from_numpy(grid).type(dtype), requires_grad=False)

    # Kernel-smoothed q1 minus kernel-smoothed xt, evaluated at every grid point.
    info = _k(points, q1_x, radius) @ q1_mu \
        - _k(points, xt_x, radius) @ xt_mu

    return [cost, info.view((res, res))]
评论列表
文章目录