def F_batch_bilinear_interp2d(input, coords):
    """Bilinearly sample a batch of images at arbitrary (row, col) points.

    Sample points outside the image are clamped to its border, so the
    nearest edge pixels are replicated.

    Parameters
    ----------
    input : torch.Tensor
        size = (N, C, H, W). Dims 2 (H) and 3 (W) are the interpolated
        spatial axes; dim 1 is treated as channels.
    coords : torch.Tensor
        size = (N, C*H*W, 2): one sample point per element of ``input``,
        in its row-major (flattened) order; or size = (N, H*W, 2), in which
        case the same points are reused for every channel.
        ``coords[..., 0]`` is the position along H and ``coords[..., 1]``
        the position along W, in pixel units.

    Returns
    -------
    torch.Tensor
        Sampled values with the same shape as ``input``.

    Raises
    ------
    ValueError
        If ``input`` is not 4-D or ``coords`` has an incompatible shape.
    """
    if input.dim() != 4:
        raise ValueError(
            "input must be 4-D (N, C, H, W), got %d-D" % input.dim())
    n, c, h, w = input.size()
    plane = h * w
    if coords.dim() != 3 or coords.size(0) != n or coords.size(2) < 2:
        raise ValueError(
            "coords must have size (N, K, 2) with N=%d, got %s"
            % (n, tuple(coords.size())))
    if c > 1 and coords.size(1) == plane:
        # One point per pixel: reuse the same points for every channel.
        coords = coords.repeat(1, c, 1)
    elif coords.size(1) != c * plane:
        raise ValueError(
            "coords must hold C*H*W=%d (or H*W=%d) points, got %d"
            % (c * plane, plane, coords.size(1)))

    # Clamp to the valid pixel range [0, size-1] (border replication).
    x = coords[:, :, 0].clamp(0, h - 1)
    y = coords[:, :, 1].clamp(0, w - 1)
    # Keep the lower corner at most size-2 so that points in the last cell
    # still interpolate towards the last row/column (a point exactly on it
    # gets weight 1 there) instead of being pulled back to size-2.
    # max(..., 0) keeps size-1 axes valid.
    x0 = x.floor().clamp(max=max(h - 2, 0))
    y0 = y.floor().clamp(max=max(w - 2, 0))
    xd = x - x0
    yd = y - y0
    xm = 1 - xd
    ym = 1 - yd

    # Integer corner indices (cast before scaling so the index math is
    # exact); the upper corner is clamped so size-1 axes stay in range.
    x0_ix = x0.long()
    y0_ix = y0.long()
    x1_ix = (x0_ix + 1).clamp(max=h - 1)
    y1_ix = (y0_ix + 1).clamp(max=w - 1)

    # Flatten each sample in row-major order: element (ch, i, j) lives at
    # ch*H*W + i*W + j. reshape copies when input is not contiguous, so the
    # offsets below never depend on its memory strides.
    flat = input.reshape(n, c * plane)
    # Offset of each output element's own channel, so channel ch samples
    # from channel ch; shape (1, C*H*W) broadcasts over the batch.
    chan_off = (torch.arange(c, device=coords.device) * plane).view(c, 1)
    chan_off = chan_off.expand(c, plane).reshape(1, c * plane)

    row0 = chan_off + x0_ix * w
    row1 = chan_off + x1_ix * w
    vals_00 = flat.gather(1, row0 + y0_ix)
    vals_10 = flat.gather(1, row1 + y0_ix)
    vals_01 = flat.gather(1, row0 + y1_ix)
    vals_11 = flat.gather(1, row1 + y1_ix)

    x_mapped = (vals_00 * xm * ym +
                vals_10 * xd * ym +
                vals_01 * xm * yd +
                vals_11 * xd * yd)
    return x_mapped.view_as(input)
# Web-page residue from the original article: "评论列表" (comment list),
# "文章目录" (table of contents).