mul_lazy_variable_test.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_matmul_approx():
    """Check MulLazyVariable's approximate matmul against a dense reference.

    Builds ``d`` covariance lazy variables from a grid-inducing-point model
    conditioned on random data, multiplies their dense evaluations
    elementwise to obtain a reference matrix, and asserts that the
    approximate ``matmul`` of the ``MulLazyVariable`` with a random vector is
    within 1% relative error (in norm) of the reference product.
    """

    class KissGPModel(gpytorch.GridInducingPointModule):
        """Grid-inducing-point model with a constant mean and an RBF kernel."""

        def __init__(self):
            super(KissGPModel, self).__init__(grid_size=300, grid_bounds=[(0, 1)])
            self.mean_module = ConstantMean(constant_bounds=(-1, 1))
            covar_module = RBFKernel(log_lengthscale_bounds=(-100, 100))
            # Fix the log-lengthscale to a known value for this test.
            covar_module.log_lengthscale.data = torch.FloatTensor([-2])
            self.covar_module = covar_module

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return GaussianRandomVariable(mean_x, covar_x)

    model = KissGPModel()

    n = 100  # number of random input points per factor
    d = 4    # number of lazy variables multiplied together

    lazy_var_list = []
    lazy_var_eval_list = []

    for _ in range(d):
        x = Variable(torch.rand(n))
        y = Variable(torch.rand(n))
        model.condition(x, y)
        toeplitz_var = model(x).covar()
        lazy_var_list.append(toeplitz_var)
        # Keep the dense evaluation alongside the lazy form for the reference.
        lazy_var_eval_list.append(toeplitz_var.evaluate().data)

    mul_lazy_var = MulLazyVariable(*lazy_var_list, matmul_mode='approximate', max_iter=30)

    # Dense reference: elementwise product of the evaluated matrices.
    # (Multiplying each factor by an identity matrix first is a no-op, so the
    # evaluated matrices are used directly.)
    mul_lazy_var_eval = torch.ones(n, n)
    for lazy_var_eval in lazy_var_eval_list:
        mul_lazy_var_eval *= lazy_var_eval

    vec = torch.randn(n)

    actual = mul_lazy_var_eval.matmul(vec)
    res = mul_lazy_var.matmul(Variable(vec)).data

    # Relative error in norm must stay below 1%.
    assert torch.norm(actual - res) / torch.norm(actual) < 1e-2
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号