vision.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def backward(ctx, grad_grid):
        N, C, H, W = ctx.size
        assert grad_grid.size() == torch.Size([N, H, W, 2])
        assert ctx.is_cuda == grad_grid.is_cuda
        if grad_grid.is_cuda:
            AffineGridGenerator._enforce_cudnn(grad_grid)
            grad_theta = grad_grid.new(N, 2, 3)
            grad_grid = grad_grid.contiguous()
            torch._C._cudnn_affine_grid_generator_backward(grad_theta, grad_grid,
                                                           N, C, H, W)
        else:
            base_grid = ctx.base_grid
            grad_theta = torch.bmm(
                base_grid.view(N, H * W, 3).transpose(1, 2),
                grad_grid.view(N, H * W, 2))
            grad_theta = grad_theta.transpose(1, 2)

        return grad_theta, None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号