interpolation.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def interpolate(self, x_grid, x_target, interp_points=range(-2, 2)):
        """Build interpolation indices and weights of target points onto a 1-D grid.

        For every target point a stencil of ``len(interp_points)`` consecutive
        grid nodes is chosen, and each node's weight is obtained from
        ``self._cubic_interpolation_kernel`` applied to the signed distance
        (in grid spacings) between that node and the target.  Target points
        whose stencil would run off either end of the grid instead get a
        one-hot weight on the nearest of the first (resp. last)
        ``len(interp_points)`` grid nodes.

        Args:
            x_grid: 1-D tensor of grid nodes.  The spacing is taken from the
                first two nodes only, so the grid is assumed to be ascending,
                evenly spaced, and to hold at least ``len(interp_points)`` nodes.
            x_target: tensor of target points.  It is flattened, so shapes
                ``(n,)`` and ``(n, 1)`` are both accepted (including ``n == 1``).
            interp_points: ascending, contiguous integer offsets defining the
                stencil.  The default ``range(-2, 2)`` selects nodes
                ``floor(t) - 1 .. floor(t) + 2`` for a target at scaled
                position ``t``.

        Returns:
            ``(J, C)``: ``J`` is an ``(n, len(interp_points))`` LongTensor of
            grid indices and ``C`` a tensor of the same shape (created from
            ``x_target``) holding the matching interpolation weights.
        """
        interp_points = list(interp_points)
        num_coefficients = len(interp_points)

        # Flatten rather than squeeze(): squeeze() turns a single target point
        # into a 0-dim tensor and leaves (n, 1) inputs broadcasting to (n, n).
        x_target = x_target.contiguous().view(-1)
        num_grid_points = len(x_grid)
        num_target_points = len(x_target)

        grid_delta = x_grid[1] - x_grid[0]

        # Target position in grid spacings, split into the index of the grid
        # node at or below it and the fractional offset from that node.
        scaled_pos = (x_target - x_grid[0]) / grid_delta
        lower_grid_pt_idxs = torch.floor(scaled_pos)
        lower_pt_rel_dists = scaled_pos - lower_grid_pt_idxs
        # Shift so the index points at the first node of each target's stencil.
        lower_grid_pt_idxs = lower_grid_pt_idxs - interp_points[-1]

        C = x_target.new(num_target_points, num_coefficients).zero_()
        for i in range(num_coefficients):
            # Signed distance (in grid spacings) from node
            # lower_grid_pt_idxs + i to the target.
            scaled_dist = lower_pt_rel_dists + interp_points[-i - 1]
            C[:, i] = self._cubic_interpolation_kernel(scaled_dist)

        def snap_to_nearest(pts, grid_window, stencil_start):
            # Give each point listed in ``pts`` weight 1 on its nearest node of
            # ``grid_window`` (the first or last num_coefficients grid nodes)
            # and anchor its stencil at ``stencil_start``.
            pts = pts.view(-1)
            num_pts = pts.numel()
            if num_pts == 0:
                return
            nodes = grid_window.unsqueeze(0).expand(num_pts, num_coefficients)
            targets = x_target[pts].unsqueeze(1).expand(num_pts, num_coefficients)
            closest = torch.min(torch.abs(nodes - targets), 1)[1]
            C[pts] = 0
            C[pts, closest] = 1
            lower_grid_pt_idxs[pts] = stencil_start

        # Points whose stencil would start before index 0 fall off the left
        # end of the grid; a stencil starting exactly at 0 is still valid.
        snap_to_nearest(torch.nonzero(lower_grid_pt_idxs < 0),
                        x_grid[:num_coefficients], 0)

        # Points whose stencil would end past the last node fall off the
        # right end of the grid.
        last_start = num_grid_points - num_coefficients
        snap_to_nearest(torch.nonzero(lower_grid_pt_idxs > last_start),
                        x_grid[-num_coefficients:], last_start)

        J = x_grid.new(num_target_points, num_coefficients).zero_()
        for i in range(num_coefficients):
            J[:, i] = lower_grid_pt_idxs + i

        J = J.long()
        return J, C
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号