conv.py source code

python

Project: pytorch-dist  Author: apaszke
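def forward(self, input, weight, bias=None):
    # Allocate an output tensor of the same type as `input`, shaped by _output_size().
    output = input.new(*self._output_size(input, weight))
    # Keep the tensors backward() will need (bias only if one was given).
    if bias is not None:
        self.save_for_backward(input, weight, bias)
    else:
        self.save_for_backward(input, weight)

    if cudnn.is_acceptable(input):
        # cuDNN path: padding and stride are passed in (height, width) order.
        self._cudnn_info = torch._C._cudnn_convolution_forward(
            input, weight, bias, output, self.pad[0], self.pad[1],
            self.stride[0], self.stride[1], self.groups, cudnn.benchmark)
    else:
        # Fallback when cuDNN is unavailable or rejects this input;
        # the THNN backend is looked up by tensor type.
        # TODO: implement groups for THNN
        if self.groups != 1:
            raise ValueError('THNN does not support groups')
        backend = type2backend[type(input)]
        # Scratch buffers for SpatialConvolutionMM; _finput receives the unfolded (im2col) input.
        self._finput = input.new()
        self._fgrad_input = input.new()
        # THNN takes (kW, kH, dW, dH, padW, padH), i.e. width before height,
        # hence the reversed indices compared with the cuDNN call.
        backend.SpatialConvolutionMM_updateOutput(
            backend.library_state, input, output, weight, bias,
            self._finput, self._fgrad_input, weight.size(3), weight.size(2),
            self.stride[1], self.stride[0], self.pad[1], self.pad[0])

    return output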
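The THNN fallback calls SpatialConvolutionMM, the matrix-multiply ("MM") formulation of convolution: the input is unfolded into a column matrix (im2col, stored in `_finput`) and the flattened filters are multiplied against it. Below is a minimal pure-Python sketch of that idea, assuming a single image in (C, H, W) layout, weights in the (out_channels, in_channels, kH, kW) layout the snippet indexes with `weight.size(2)` and `weight.size(3)`, groups = 1, and no dilation. The output length uses the usual formula floor((in + 2*pad - kernel) / stride) + 1, which is presumably what `_output_size` computes per spatial dimension. `conv_output_size`, `im2col`, and `conv2d_forward` are hypothetical helper names, not PyTorch or THNN APIs.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical layouts for this sketch (nested lists, no batch dimension):
Image = List[List[List[float]]]           # [channels][height][width]
Filters = List[List[List[List[float]]]]   # [out_channels][in_channels][kH][kW]


def conv_output_size(in_size: int, kernel: int, stride: int, pad: int) -> int:
    """Output length along one spatial dimension (no dilation)."""
    return (in_size + 2 * pad - kernel) // stride + 1


def im2col(x: Image, kH: int, kW: int,
           stride: Tuple[int, int], pad: Tuple[int, int]) -> List[List[float]]:
    """Unfold x into a (C*kH*kW) x (oH*oW) matrix.

    Row index enumerates (channel, ky, kx); column index enumerates output
    positions (oy, ox) in row-major order. Out-of-bounds taps read zero padding.
    """
    C, H, W = len(x), len(x[0]), len(x[0][0])
    oH = conv_output_size(H, kH, stride[0], pad[0])
    oW = conv_output_size(W, kW, stride[1], pad[1])
    cols = []
    for c in range(C):
        for ky in range(kH):
            for kx in range(kW):
                row = []
                for oy in range(oH):
                    for ox in range(oW):
                        iy = oy * stride[0] - pad[0] + ky
                        ix = ox * stride[1] - pad[1] + kx
                        inside = 0 <= iy < H and 0 <= ix < W
                        row.append(x[c][iy][ix] if inside else 0.0)
                cols.append(row)
    return cols


def conv2d_forward(x: Image, weight: Filters,
                   bias: Optional[Sequence[float]] = None,
                   stride: Tuple[int, int] = (1, 1),
                   pad: Tuple[int, int] = (0, 0)) -> Image:
    """groups=1 convolution of one image, computed as W_flat @ im2col(x) + bias."""
    out_c, in_c = len(weight), len(weight[0])
    kH, kW = len(weight[0][0]), len(weight[0][0][0])
    if in_c != len(x):
        raise ValueError("channel mismatch (groups != 1 is not supported here)")
    oH = conv_output_size(len(x[0]), kH, stride[0], pad[0])
    oW = conv_output_size(len(x[0][0]), kW, stride[1], pad[1])
    cols = im2col(x, kH, kW, stride, pad)
    out = []
    for o in range(out_c):
        # Flatten filter o in the same (channel, ky, kx) order as the im2col rows.
        w_row = [weight[o][c][ky][kx]
                 for c in range(in_c) for ky in range(kH) for kx in range(kW)]
        b = bias[o] if bias is not None else 0.0
        # flat[p] is entry (o, p) of W_flat @ cols, plus the bias.
        flat = [b + sum(w * col[p] for w, col in zip(w_row, cols))
                for p in range(oH * oW)]
        out.append([flat[oy * oW:(oy + 1) * oW] for oy in range(oH)])
    return out


if __name__ == "__main__":
    x = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]   # 1 channel, 3x3
    w = [[[[1, 1], [1, 1]]]]                   # 1 filter, 2x2 box sum
    print(conv2d_forward(x, w))                # [[[12.0, 16.0], [24.0, 28.0]]]
```

Unfolding copies each input value up to kH*kW times, trading memory for a single large matrix product. The snippet keeps the buffer on `self`, likely so the backward pass can reuse it instead of unfolding the input again.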