test_cubic_interpolation.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_interpolation():
    x = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)
    J, C = Interpolation().interpolate(grid, x)
    W = utils.toeplitz.index_coef_to_sparse(J, C, len(grid))
    test_func_grid = grid.pow(2)
    test_func_x = x.pow(2)

    interp_func_x = torch.dsmm(W, test_func_grid.unsqueeze(1)).squeeze()

    assert all(torch.abs(interp_func_x - test_func_x) / (test_func_x + 1e-10) < 1e-5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号