def forward(ctx, input, grid):
    """Run the grid-sampler forward pass and return the sampled tensor.

    ``input`` and ``grid`` are stored on ``ctx`` via ``save_for_backward``
    before any contiguity conversion is applied. The result is a freshly
    allocated tensor created with ``input.new`` and has shape
    ``(grid.size(0), input.size(1), grid.size(1), grid.size(2))``.

    NOTE(review): presumably ``input`` is (N, C, H_in, W_in) and ``grid`` is
    (N, H_out, W_out, 2) -- confirm against callers.

    Dispatch:
      * if ``cudnn.is_acceptable(input)`` is true, the cuDNN kernel
        ``torch._C._cudnn_grid_sampler_forward`` fills the output;
      * otherwise the backend registered for ``type(input)`` in
        ``type2backend`` runs ``SpatialGridSamplerBilinear_updateOutput``.

    Args:
        ctx: autograd context object (used here only for
            ``save_for_backward``).
        input: tensor to sample from.
        grid: sampling-location tensor; its first three sizes define the
            output's batch and spatial dimensions.

    Returns:
        The newly allocated output tensor.
    """
    ctx.save_for_backward(input, grid)
    grid_shape = grid.size()

    if not cudnn.is_acceptable(input):
        # Non-cuDNN path: use the backend registered for this tensor type.
        impl = type2backend[type(input)]
        result = input.new(
            grid_shape[0], input.size(1), grid_shape[1], grid_shape[2])
        impl.SpatialGridSamplerBilinear_updateOutput(
            impl.library_state, input, grid, result)
        return result

    result = input.new(
        grid_shape[0], input.size(1), grid_shape[1], grid_shape[2])
    # cuDNN path: the grid is always handed over contiguous, and an input
    # with any zero stride (e.g. one produced by expand) is materialised
    # first -- presumably the cuDNN kernel cannot consume such layouts;
    # confirm.
    grid = grid.contiguous()
    if 0 in input.stride():
        input = input.contiguous()
    torch._C._cudnn_grid_sampler_forward(input, grid, result)
    return result
评论列表
文章目录