sum_lazy_variable_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_exact_posterior():
    """Compare the exact posterior mean of a summed lazy variable with its evaluated form.

    Two interpolated Toeplitz lazy variables are added together. The posterior
    mean obtained through the sum's own posterior strategy must agree (norm of
    the difference below 1e-4) with the mean obtained from
    ``gpytorch.posterior_strategy`` applied to the result of ``evaluate()``.
    """
    # Random inputs; the draw order (train mean, train targets, test mean) is kept fixed.
    mean_at_train = Variable(torch.randn(4))
    targets_at_train = Variable(torch.randn(4))
    mean_at_test = Variable(torch.randn(4))

    # Lazy side: two Toeplitz columns wrapped in interpolation with unit weights
    # on indices 0..3 for both the left and right interpolation arguments.
    first_column = Variable(torch.Tensor([5, 1, 2, 0]), requires_grad=True)
    second_column = Variable(torch.Tensor([6, 0, 1, -1]), requires_grad=True)
    interp_idx = Variable(torch.arange(0, 4).long().view(4, 1))
    interp_val = Variable(torch.ones(4).view(4, 1))
    lazy_first = InterpolatedLazyVariable(
        ToeplitzLazyVariable(first_column), interp_idx, interp_val, interp_idx, interp_val
    )
    lazy_second = InterpolatedLazyVariable(
        ToeplitzLazyVariable(second_column), interp_idx, interp_val, interp_idx, interp_val
    )
    lazy_sum = lazy_first + lazy_second

    # Reference side: the evaluated form of the same sum.
    evaluated_sum = lazy_sum.evaluate()

    # Forward pass through the reference strategy (a fresh strategy per call, as before).
    reference_alpha = gpytorch.posterior_strategy(evaluated_sum).exact_posterior_alpha(
        mean_at_train, targets_at_train
    )
    reference_mean = gpytorch.posterior_strategy(evaluated_sum).exact_posterior_mean(
        mean_at_test, reference_alpha
    )

    # Forward pass through the lazy sum's own strategy.
    lazy_alpha = lazy_sum.posterior_strategy().exact_posterior_alpha(mean_at_train, targets_at_train)
    lazy_mean = lazy_sum.posterior_strategy().exact_posterior_mean(mean_at_test, lazy_alpha)

    residual = reference_mean.data - lazy_mean.data
    assert torch.norm(residual) < 1e-4
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号