rbf_kernel_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_computes_radial_basis_function_gradient():
    a = torch.Tensor([4, 2, 8]).view(3, 1)
    b = torch.Tensor([0, 2, 2]).view(3, 1)
    lengthscale = 2

    kernel = RBFKernel().initialize(log_lengthscale=math.log(lengthscale))
    kernel.eval()
    param = Variable(torch.Tensor(3, 3).fill_(math.log(lengthscale)), requires_grad=True)
    diffs = Variable(a.expand(3, 3) - b.expand(3, 3).transpose(0, 1))
    actual_output = (-(diffs ** 2) / param.exp()).exp()
    actual_output.backward(torch.eye(3))
    actual_param_grad = param.grad.data.sum()

    output = kernel(Variable(a), Variable(b))
    output.backward(gradient=torch.eye(3))
    res = kernel.log_lengthscale.grad.data
    assert(torch.norm(res - actual_param_grad) < 1e-5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号