def __init__(self, grid_size, grid_bounds):
    """Build a regular grid of inducing points covering ``grid_bounds``.

    For every dimension, ``grid_size`` evenly spaced values are generated from
    ``lower - grid_diff`` to ``upper + grid_diff``, where
    ``grid_diff = (upper - lower) / (grid_size - 2)``. The Cartesian product of
    these per-dimension values, one point per row, is passed to the parent
    constructor. The per-dimension values are also kept in the ``grid`` buffer.

    Args:
        grid_size: Number of grid values per dimension. Must be at least 3,
            because the padding ``grid_diff`` divides by ``grid_size - 2``.
        grid_bounds: Sequence with one ``(lower, upper)`` pair per input
            dimension (only items ``[0]`` and ``[1]`` of each pair are read).

    Raises:
        ValueError: If ``grid_size`` is smaller than 3. ``grid_size == 2``
            would otherwise divide by zero, and smaller sizes would give a
            negative padding, i.e. a degenerate grid.
    """
    if grid_size < 3:
        raise ValueError(
            "grid_size must be at least 3 (got {}); the grid padding "
            "divides by grid_size - 2".format(grid_size)
        )
    num_dims = len(grid_bounds)

    # grid[i] holds the grid_size values for dimension i. The padding by
    # grid_diff on both ends places the bounds strictly inside the grid
    # (when lower < upper).
    grid = torch.zeros(num_dims, grid_size)
    for i, bounds in enumerate(grid_bounds):
        lower, upper = bounds[0], bounds[1]
        grid_diff = float(upper - lower) / (grid_size - 2)
        grid[i] = torch.linspace(lower - grid_diff, upper + grid_diff, grid_size)

    # Cartesian product of the per-dimension grids, one point per row.
    # Row r holds (grid[0, j_0], ..., grid[d-1, j_{d-1}]) with
    # r = sum_i j_i * grid_size ** i, so dimension 0 varies fastest.
    # It is built incrementally: after dimension i is processed, the first
    # grid_size ** (i + 1) rows contain the product over dimensions 0..i.
    inducing_points = torch.zeros(int(grid_size ** num_dims), num_dims)
    prev_points = None
    for i in range(num_dims):
        block_len = grid_size ** i
        for j in range(grid_size):
            rows = slice(j * block_len, (j + 1) * block_len)
            inducing_points[rows, i].fill_(grid[i, j])
            if prev_points is not None:
                # Repeat the product over dimensions 0..i-1 next to grid value j.
                inducing_points[rows, :i].copy_(prev_points)
        # A view (not a copy) of the finished product over dimensions 0..i.
        prev_points = inducing_points[:block_len * grid_size, :i + 1]

    super(GridInducingPointModule, self).__init__(inducing_points)
    self.grid_size = grid_size
    self.grid_bounds = grid_bounds
    self.register_buffer('grid', grid)
# [scraped web-page residue] Comment list
# [scraped web-page residue] Table of contents