rbf_kernel_test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_computes_radial_basis_function():
    a = torch.Tensor([4, 2, 8]).view(3, 1)
    b = torch.Tensor([0, 2, 2]).view(3, 1)
    lengthscale = 2

    kernel = RBFKernel().initialize(log_lengthscale=math.log(lengthscale))
    kernel.eval()
    actual = torch.Tensor([
        [16, 4, 4],
        [4, 0, 0],
        [64, 36, 36],
    ]).mul_(-1).div_(lengthscale).exp()

    res = kernel(Variable(a), Variable(b)).data
    assert(torch.norm(res - actual) < 1e-5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号