def th_nearest_interp2d(input, coords):
    """
    2d nearest neighbor interpolation th.Tensor

    Every coordinate pair is clamped to the bounds of `input`, rounded to
    the nearest integer position, and the value stored at that position
    is looked up.

    Arguments
    ---------
    input : th.Tensor
        Tensor to sample from. Dim 0 is kept as-is for the lookup; dims 1
        and 2 are the two dims being indexed (the code assumes a 3-d
        tensor).
    coords : th.Tensor
        3-d tensor whose last dim holds the dim-1 index at position 0 and
        the dim-2 index at position 1. It has to supply one pair per
        element of `input`, because the result is viewed with `input`'s
        shape.

    Returns
    -------
    th.Tensor
        Sampled values, shaped like `input`.
    """
    # take clamp of coords so they're in the image bounds
    x = th.clamp(coords[:, :, 0], 0, input.size(1) - 1).round()
    y = th.clamp(coords[:, :, 1], 0, input.size(2) - 1).round()

    # Flatten a contiguous version of the input so that element (x, y) of
    # each dim-0 entry sits at offset x * input.size(2) + y. Deriving the
    # offset from input.stride() is wrong for non-contiguous inputs
    # (slices, transposes): their strides do not describe the flattened
    # layout, which gives wrong values or out-of-range indices.
    input_flat = input.contiguous().view(input.size(0), -1)

    # convert to integers before scaling so that large offsets are not
    # subject to floating point rounding
    flat_ix = x.long().mul(input.size(2)).add(y.long())

    mapped_vals = input_flat.gather(1, flat_ix)
    return mapped_vals.view_as(input)
# Comment list
# Article table of contents