vision.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        if cudnn.is_acceptable(input):
            grad_input = input.new(input.size())
            grad_grid = grid.new(grid.size())
            grid = grid.contiguous()
            if 0 in input.stride():
                input = input.contiguous()
            torch._C._cudnn_grid_sampler_backward(input, grad_input,
                                                  grid, grad_grid,
                                                  grad_output)
        else:
            backend = type2backend[type(input)]
            grad_input = input.new(input.size())
            grad_grid = grid.new(grid.size())
            backend.SpatialGridSamplerBilinear_updateGradInput(
                backend.library_state, input, grad_input,
                grid, grad_grid, grad_output)
        return grad_input, grad_grid
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号