def th_map_coordinates(input, coords, order=1):
    """PyTorch version of ``scipy.ndimage.map_coordinates`` for 2D input.

    Samples ``input`` at fractional positions using bilinear (order 1)
    interpolation. Unlike SciPy, ``coords`` is transposed: each row is one
    sample point and ``coords[:, k]`` is the position along axis ``k`` of
    ``input``. Positions outside the grid are clamped to the nearest edge.

    Parameters
    ----------
    input : torch.Tensor, shape (H, W)
        Grid of values to sample from. Non-square grids are clamped
        per axis.
    coords : torch.Tensor, shape (n_points, 2), floating point
        Sample positions; column 0 indexes axis 0, column 1 indexes axis 1.
    order : int
        Interpolation order. Only ``1`` (bilinear) is supported.

    Returns
    -------
    torch.Tensor
        One interpolated value per point (presumably shape ``(n_points,)``,
        depending on what ``th_gather_2d`` returns).

    Raises
    ------
    NotImplementedError
        If ``order`` is not 1.
    """
    if order != 1:
        raise NotImplementedError("only order=1 (bilinear) is supported")

    # Clamp each axis against its own extent so non-square inputs never
    # produce out-of-range indices (identical to a single clamp when H == W).
    coords = torch.stack([
        coords[:, 0].clamp(0, input.size(0) - 1),
        coords[:, 1].clamp(0, input.size(1) - 1),
    ], 1)

    # Integer corner indices around each point: floor and ceil per axis.
    floor_idx = coords.floor().long()
    ceil_idx = coords.ceil().long()
    r0, c0 = floor_idx[:, 0], floor_idx[:, 1]
    r1, c1 = ceil_idx[:, 0], ceil_idx[:, 1]

    # NOTE(review): assumes th_gather_2d(input, idx) returns
    # input[idx[:, 0], idx[:, 1]] -- confirm against its definition.
    v00 = th_gather_2d(input, floor_idx.detach())                 # (r0, c0)
    v11 = th_gather_2d(input, ceil_idx.detach())                  # (r1, c1)
    v01 = th_gather_2d(input, torch.stack([r0, c1], 1).detach())  # (r0, c1)
    v10 = th_gather_2d(input, torch.stack([r1, c0], 1).detach())  # (r1, c0)

    # Fractional offsets from the floor corner. These are the only terms
    # that depend differentiably on coords (the indices are integral).
    offset = coords - floor_idx.type_as(coords)
    d0 = offset[:, 0]
    d1 = offset[:, 1]

    # Interpolate along axis 0 at both column positions, then along axis 1.
    val_c0 = v00 + (v10 - v00) * d0
    val_c1 = v01 + (v11 - v01) * d0
    return val_c0 + (val_c1 - val_c0) * d1
# Web-page residue from the original listing (not code):
# "comment list", "table of contents".