def test_matmul_approx():
    """Check MulLazyVariable's approximate matmul against a dense reference.

    Builds ``d`` covariance lazy variables via ``model(x).covar()``. Each comes
    from the same grid-inducing-point model re-conditioned on fresh random
    data. They are combined into a ``MulLazyVariable`` with
    ``matmul_mode='approximate'`` and ``max_iter=30``.

    The test asserts that multiplying a random vector by that lazy variable
    agrees, to within 1% relative error, with multiplying the vector by the
    dense elementwise product of the evaluated covariances.
    """
    class KissGPModel(gpytorch.GridInducingPointModule):
        def __init__(self):
            super(KissGPModel, self).__init__(grid_size=300, grid_bounds=[(0, 1)])
            self.mean_module = ConstantMean(constant_bounds=(-1, 1))
            covar_module = RBFKernel(log_lengthscale_bounds=(-100, 100))
            # Fix the (log) lengthscale so every factor uses the same kernel.
            covar_module.log_lengthscale.data = torch.FloatTensor([-2])
            self.covar_module = covar_module

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return GaussianRandomVariable(mean_x, covar_x)

    model = KissGPModel()

    n = 100  # number of data points (each factor is n x n)
    d = 4    # number of factors combined by MulLazyVariable
    lazy_var_list = []
    lazy_var_eval_list = []
    for _ in range(d):
        x = Variable(torch.rand(n))
        y = Variable(torch.rand(n))
        model.condition(x, y)
        toeplitz_var = model(x).covar()
        lazy_var_list.append(toeplitz_var)
        # Keep a dense copy of each factor for the reference computation.
        lazy_var_eval_list.append(toeplitz_var.evaluate().data)

    mul_lazy_var = MulLazyVariable(*lazy_var_list, matmul_mode='approximate', max_iter=30)

    # Dense reference: elementwise product of all evaluated factors. The
    # factors are used directly; multiplying them by an identity matrix
    # first would only add an n x n matmul per factor without changing them.
    mul_lazy_var_eval = torch.ones(n, n)
    for lazy_var_eval in lazy_var_eval_list:
        mul_lazy_var_eval *= lazy_var_eval

    vec = torch.randn(n)
    actual = mul_lazy_var_eval.matmul(vec)
    res = mul_lazy_var.matmul(Variable(vec)).data
    # Approximate mode is only expected to be close, so compare by
    # relative error.
    assert torch.norm(actual - res) / torch.norm(actual) < 1e-2
评论列表
文章目录