vision.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def forward(ctx, input, grid):
        ctx.save_for_backward(input, grid)
        grid_sz = grid.size()
        if cudnn.is_acceptable(input):
            output = input.new(grid_sz[0], input.size(1), grid_sz[1], grid_sz[2])
            grid = grid.contiguous()
            if 0 in input.stride():
                input = input.contiguous()
            torch._C._cudnn_grid_sampler_forward(input, grid, output)
        else:
            backend = type2backend[type(input)]
            output = input.new(grid_sz[0], input.size(1), grid_sz[1], grid_sz[2])
            backend.SpatialGridSamplerBilinear_updateOutput(backend.library_state, input, grid, output)
        return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号