def backward(ctx, grad_grid):
    """Backward pass of the affine grid generator.

    Maps the gradient w.r.t. the generated sampling grid (shape
    ``(N, H, W, 2)``) back to a gradient w.r.t. the affine transform
    parameters ``theta`` (shape ``(N, 2, 3)``).

    Args:
        ctx: autograd context. The forward pass is expected to have stored
            ``ctx.size`` (the ``(N, C, H, W)`` output size), ``ctx.is_cuda``,
            and — on the CPU path — ``ctx.base_grid``, presumably the
            ``(N, H, W, 3)`` grid of homogeneous base coordinates built in
            forward (TODO confirm against the forward implementation).
        grad_grid: gradient w.r.t. the generated grid, ``(N, H, W, 2)``.

    Returns:
        Tuple ``(grad_theta, None)``: ``grad_theta`` has shape ``(N, 2, 3)``;
        the ``None`` corresponds to the forward's non-differentiable
        ``size`` argument.
    """
    N, C, H, W = ctx.size
    # Internal invariants from forward; kept as asserts (not validation of
    # untrusted input), matching the original contract.
    assert grad_grid.size() == torch.Size([N, H, W, 2])
    assert ctx.is_cuda == grad_grid.is_cuda
    if grad_grid.is_cuda:
        AffineGridGenerator._enforce_cudnn(grad_grid)
        # new_empty replaces the deprecated Tensor.new(...): it allocates the
        # uninitialized output buffer that the cuDNN kernel fills in place.
        grad_theta = grad_grid.new_empty((N, 2, 3))
        grad_grid = grad_grid.contiguous()
        torch._C._cudnn_affine_grid_generator_backward(grad_theta, grad_grid,
                                                       N, C, H, W)
    else:
        # CPU path: grad_theta[n, i, j] = sum over all (h, w) positions of
        # base_grid[n, h, w, j] * grad_grid[n, h, w, i], computed as one
        # batched matmul over the flattened H*W dimension, then transposed
        # from (N, 3, 2) to the (N, 2, 3) layout of theta.
        base_grid = ctx.base_grid
        grad_theta = torch.bmm(
            base_grid.view(N, H * W, 3).transpose(1, 2),
            grad_grid.view(N, H * W, 2))
        grad_theta = grad_theta.transpose(1, 2)
    return grad_theta, None
# NOTE(review): the two lines below were stray blog-page artifacts from a web
# scrape ("评论列表" = "comment list", "文章目录" = "article table of contents"),
# not part of the source code; converted to this comment so the file parses.