def th_nearest_interp3d(input, coords):
    """
    3d nearest neighbor interpolation on a th.Tensor.

    Each row of ``coords`` is a (dim1, dim2, dim3) position into ``input``.
    The position is clamped to the volume bounds and rounded to the nearest
    voxel, and the value at that voxel is gathered.

    Args:
        input: 4-D tensor. Dims 1, 2 and 3 are indexed by columns 0, 1 and 2
            of ``coords``. Dim 0 is never indexed, so every sample is read
            from the first slice along dim 0. This presumably means ``input``
            is a single-channel volume (TODO confirm for dim 0 size > 1).
        coords: (N, 3) floating-point tensor of sample positions. It is NOT
            modified.

    Returns:
        Tensor of the sampled values, reshaped with ``view_as(input)``. This
        requires N == input.numel().
    """
    # Clamp into bounds, then snap to the nearest voxel. These are new
    # tensors, so the caller's coords are left untouched. Converting to long
    # keeps the index arithmetic exact and independent of coords' dtype.
    i1 = th.clamp(coords[:, 0], 0, input.size(1) - 1).round().long()
    i2 = th.clamp(coords[:, 1], 0, input.size(2) - 1).round().long()
    i3 = th.clamp(coords[:, 2], 0, input.size(3) - 1).round().long()

    # Flatten a contiguous (row-major) copy. Its strides for dims 1..3 are
    # (size2 * size3, size3, 1), regardless of input's own memory layout.
    # Python-int strides keep this device-agnostic.
    input_flat = input.contiguous().view(-1)
    stride3 = 1
    stride2 = input.size(3)
    stride1 = input.size(2) * input.size(3)

    idx = i1 * stride1 + i2 * stride2 + i3 * stride3
    mapped_vals = input_flat[idx]

    return mapped_vals.view_as(input)
评论列表
文章目录