def test_computes_radial_basis_function():
    """Check RBFKernel output against a hand-computed kernel matrix.

    The expected matrix is exp(-(a[i] - b[j]) ** 2 / lengthscale) for every
    pair of points drawn from ``a`` and ``b``.

    NOTE(review): this form has no 0.5 factor and does not square the
    lengthscale. It encodes whatever convention the targeted ``RBFKernel``
    version uses; confirm if the library's parameterisation changes.
    """
    lengthscale = 2
    a = torch.Tensor([4, 2, 8]).view(3, 1)
    b = torch.Tensor([0, 2, 2]).view(3, 1)

    # Pairwise squared distances (a[i] - b[j]) ** 2, written out by hand.
    squared_distances = torch.Tensor([
        [16, 4, 4],
        [4, 0, 0],
        [64, 36, 36],
    ])
    expected = squared_distances.neg().div(lengthscale).exp()

    # The kernel stores its lengthscale in log space.
    kernel = RBFKernel().initialize(log_lengthscale=math.log(lengthscale))
    kernel.eval()
    result = kernel(Variable(a), Variable(b)).data

    assert torch.norm(result - expected) < 1e-5
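# 评论列表 ("Comment list") — web-page residue captured with this snippet
# 文章目录 ("Table of contents") — web-page residue captured with this snippet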