def forward(ctx, theta, size):
    """Generate a 2-D affine sampling grid from a batch of affine matrices.

    Args:
        ctx: Autograd context. This function stores ``size`` and ``is_cuda``
            on it, and on the CPU path also ``base_grid`` (presumably consumed
            by the matching backward -- confirm).
        theta: Batch of affine matrices. The CPU path multiplies by
            ``theta.transpose(1, 2)`` and reshapes the product to
            ``(N, H, W, 2)``, so it expects shape ``(N, 2, 3)``.
        size: ``torch.Size`` of the target output, unpacked as
            ``(N, C, H, W)``.

    Returns:
        Tensor of shape ``(N, H, W, 2)``. Each entry is the affine-transformed
        normalized ``(x, y)`` coordinate for that output pixel.

    Raises:
        AssertionError: if ``size`` is not a ``torch.Size`` (only when
            assertions are enabled).
    """
    assert isinstance(size, torch.Size), "size must be a torch.Size"
    N, C, H, W = size
    ctx.size = size
    ctx.is_cuda = theta.is_cuda

    if theta.is_cuda:
        AffineGridGenerator._enforce_cudnn(theta)
        grid = theta.new(N, H, W, 2)
        # The cuDNN kernel writes into ``grid`` and is handed a contiguous theta.
        theta = theta.contiguous()
        torch._C._cudnn_affine_grid_generator_forward(theta, grid, N, C, H, W)
        return grid

    # Homogeneous base grid:
    #   [..., 0] is x, which varies along W;
    #   [..., 1] is y, which varies along H;
    #   [..., 2] is the constant 1 that picks up theta's translation column.
    # Both axes span [-1, 1]; a size-1 axis collapses to the single value -1.
    base_grid = theta.new(N, H, W, 3)
    xs = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
    ys = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
    # Broadcasting fills every row/column directly, without building the
    # (H, W) outer-product temporaries that torch.ger + expand_as needed.
    base_grid[:, :, :, 0] = xs                # (W,)   -> (N, H, W)
    base_grid[:, :, :, 1] = ys.unsqueeze(1)   # (H, 1) -> (N, H, W)
    base_grid[:, :, :, 2] = 1
    ctx.base_grid = base_grid

    # (N, H*W, 3) @ (N, 3, 2) -> (N, H*W, 2)
    grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
    return grid.view(N, H, W, 2)
评论列表
文章目录