def backward(ctx, grad_output):
    """Compute gradients of the grid sampler w.r.t. ``input`` and ``grid``.

    Dispatches to cuDNN when ``cudnn.is_acceptable(input)`` holds, otherwise
    to the legacy backend's ``SpatialGridSamplerBilinear_updateGradInput``.

    Args:
        ctx: autograd context; ``ctx.saved_tensors`` must unpack to
            ``(input, grid)`` as saved by the matching forward.
        grad_output: gradient of the loss w.r.t. the sampler's output.

    Returns:
        ``(grad_input, grad_grid)``, sized like ``input`` and ``grid``.
    """
    input, grid = ctx.saved_tensors

    # Fresh output buffers sized like the saved tensors. ``.new(size)`` does
    # not zero them here; both backends below are relied upon to fill them.
    # NOTE(review): confirm each kernel writes/zeroes every element.
    grad_input = input.new(input.size())
    grad_grid = grid.new(grid.size())

    if cudnn.is_acceptable(input):
        # The cuDNN path is given a contiguous grid, and zero-stride
        # (expanded) tensors are densified before the call.
        grid = grid.contiguous()
        if 0 in input.stride():
            input = input.contiguous()
        # grad_output can be an expanded, zero-stride tensor (e.g. the
        # gradient flowing back from ``.sum()``); it needs the same treatment
        # as ``input`` before being handed to cuDNN.
        if 0 in grad_output.stride():
            grad_output = grad_output.contiguous()
        torch._C._cudnn_grid_sampler_backward(input, grad_input,
                                              grid, grad_grid,
                                              grad_output)
    else:
        # Non-cuDNN path: look up the backend for this tensor type.
        backend = type2backend[type(input)]
        backend.SpatialGridSamplerBilinear_updateGradInput(
            backend.library_state, input, grad_input,
            grid, grad_grid, grad_output)
    return grad_input, grad_grid
评论列表
文章目录