numpy_extensions_tutorial.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tutorials 作者: pytorch 项目源码 文件源码
def backward(self, grad_output):
        input, filter = self.saved_tensors
        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
        return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号