lddmm_pytorch.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:lddmm-ot 作者: jeanfeydy 项目源码 文件源码
def _kernel_matching(q1_x, q1_mu, xt_x, xt_mu, radius, res=10):
    """
    Kernel-fidelity term between two weighted point clouds q1 and xt.

    Each measure is given as a pair (locations, weights):
    (q1_x, q1_mu) and (xt_x, xt_mu).

    Args:
        q1_x, xt_x: point locations of the two measures.
            NOTE(review): the info grid below is built from 2-column
            points, so locations are presumably (N, 2) - confirm.
        q1_mu, xt_mu: 1D weight vectors (they are passed to torch.ger).
        radius: kernel scale, forwarded untouched to _cross_kernels and _k.
        res: number of grid cells per axis used for the 'info' picture.
            Increase it to get smoother pictures. Defaults to 10.

    Returns:
        [cost, info] where
        - cost = .5 * ( <mu, K_qq mu> + <nu, K_xx nu> - 2 <mu, K_qx nu> ),
          with mu = q1_mu, nu = xt_mu and K_qq, K_qx, K_xx the three arrays
          returned by _cross_kernels (presumably the kernel matrices
          between the respective point sets);
        - info is a (res, res) array: the difference of the two
          kernel-smoothed measures, sampled at the cell centers of a
          regular grid on the unit square [0, 1] x [0, 1].

    Raises:
        ValueError: if res is smaller than 1.
    """
    if res < 1:
        raise ValueError("res must be a positive integer, got %r" % (res,))

    K_qq, K_qx, K_xx = _cross_kernels(q1_x, xt_x, radius)

    # Expanded form of half a squared norm: two self-interaction terms
    # minus twice the cross term. torch.ger builds the outer product of
    # the weight vectors so that it can be multiplied entrywise with K.
    self_q = torch.sum(K_qq * torch.ger(q1_mu, q1_mu))
    self_x = torch.sum(K_xx * torch.ger(xt_mu, xt_mu))
    cross = torch.sum(K_qx * torch.ger(q1_mu, xt_mu))
    cost = .5 * (self_q + self_x - 2 * cross)

    # Info = the 2D graph of the blurred difference between the measures.
    # Sample points are the centers of a res x res grid on the unit square:
    # drop the right endpoint of linspace, then shift by half a cell.
    ticks = np.linspace(0, 1, res + 1)[:-1] + 1 / (2 * res)
    X, Y = np.meshgrid(ticks, ticks)
    grid = np.vstack((X.ravel(), Y.ravel())).T  # one (x, y) row per cell
    # The grid is a constant: no gradient should flow through it.
    points = Variable(torch.from_numpy(grid).type(dtype), requires_grad=False)

    info = _k(points, q1_x, radius) @ q1_mu \
         - _k(points, xt_x, radius) @ xt_mu
    return [cost, info.view((res, res))]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号