vision.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def forward(ctx, theta, size):
        """Build a sampling grid by applying affine matrices to a base grid.

        Args:
            ctx: autograd context; ``size``, ``is_cuda`` and (on the CPU
                path) ``base_grid`` are stored on it.
            theta: batch of affine matrices. The CPU path multiplies an
                ``(N, H * W, 3)`` base grid by ``theta.transpose(1, 2)`` and
                views the result as ``(N, H, W, 2)``, so ``theta`` must be
                ``(N, 2, 3)``.
            size: ``torch.Size`` unpacked as ``(N, C, H, W)``.

        Returns:
            A tensor of shape ``(N, H, W, 2)`` created from ``theta``.
        """
        assert isinstance(size, torch.Size)
        N, C, H, W = size
        ctx.size = size
        if theta.is_cuda:
            ctx.is_cuda = True
            AffineGridGenerator._enforce_cudnn(theta)
            grid = theta.new(N, H, W, 2)
            theta = theta.contiguous()
            torch._C._cudnn_affine_grid_generator_forward(theta, grid, N, C, H, W)
        else:
            ctx.is_cuda = False
            base_grid = theta.new(N, H, W, 3)

            # Build the coordinate vectors with theta's own tensor type.
            # Creating them as default float32 tensors and copying them into
            # a float64 base grid would bake float32 rounding error into the
            # coordinates.
            if W > 1:
                x_points = torch.linspace(-1, 1, W, out=theta.new())
            else:
                # A single column sits at -1.
                x_points = theta.new(1).fill_(-1)
            if H > 1:
                y_points = torch.linspace(-1, 1, H, out=theta.new())
            else:
                # A single row sits at -1.
                y_points = theta.new(1).fill_(-1)
            ones_h = theta.new(H).fill_(1)
            ones_w = theta.new(W).fill_(1)

            # Channel 0: x coordinate varies along W, constant along H.
            x_plane = base_grid[:, :, :, 0]
            base_grid[:, :, :, 0] = torch.ger(ones_h, x_points).expand_as(x_plane)
            # Channel 1: y coordinate varies along H, constant along W.
            y_plane = base_grid[:, :, :, 1]
            base_grid[:, :, :, 1] = torch.ger(y_points, ones_w).expand_as(y_plane)
            # Channel 2: homogeneous coordinate.
            base_grid[:, :, :, 2] = 1
            ctx.base_grid = base_grid

            # (N, H*W, 3) @ (N, 3, 2) -> (N, H*W, 2)
            grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
            grid = grid.view(N, H, W, 2)
        return grid
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号