conv.py source code

python

Project: pytorch-dist  Author: apaszke
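def forward(self, input, weight, bias=None):
    # Select the THNN (CPU) or THCUNN (CUDA) backend matching the input tensor type.
    self._backend = type2backend[type(input)]
    # TODO: free buffers when not needed
    # Scratch tensors that the backend kernels fill in (e.g. the unfolded input).
    self.buffer1 = input.new()
    self.buffer2 = input.new()
    output = input.new()
    self.with_bias = bias is not None
    if torch.typename(input) == 'torch.cuda.FloatTensor':
        # CUDA float tensors use VolumetricConvolution, which takes two scratch
        # buffers and skips the first three additional_args (presumably the
        # kernel sizes, which it can read from weight).
        self._backend.VolumetricConvolution_updateOutput(
            self._backend.library_state, input, output, weight, bias,
            self.buffer1, self.buffer2, *self.additional_args[3:])
    else:
        # All other tensor types use the MM variant (unfold + matrix multiply),
        # which takes a single scratch buffer and the full additional_args.
        self._backend.VolumetricConvolutionMM_updateOutput(
            self._backend.library_state, input, output, weight,
            bias, self.buffer1, *self.additional_args)
    # Save what backward needs; bias is saved only when it was given.
    if self.with_bias:
        self.save_for_backward(input, weight, bias)
    else:
        self.save_for_backward(input, weight)
    return output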
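The non-CUDA branch calls VolumetricConvolutionMM_updateOutput. As the MM suffix suggests, that variant computes the convolution as one matrix multiply over an unfolded ("vol2col") copy of the input, which is presumably what buffer1 holds. Below is a minimal pure-Python sketch of that unfold-then-multiply idea. The helper names vol2col and conv3d_mm, the nested-list C x D x H x W layout, the single-sample input (no batch dimension), and the (depth, height, width) ordering of stride and padding are all assumptions for illustration; the sketch does not reproduce THNN's actual buffer layout or argument order.

```python
from itertools import product


def vol2col(x, k, stride, pad):
    """Unfold a zero-padded C x D x H x W volume (nested lists) into columns.

    Hypothetical helper (not THNN's API). Row index: filter offset (c, dt, dh, dw);
    column index: output position (t, i, j).
    """
    C, D, H, W = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    (kT, kH, kW), (sT, sH, sW), (pT, pH, pW) = k, stride, pad
    oD = (D + 2 * pT - kT) // sT + 1
    oH = (H + 2 * pH - kH) // sH + 1
    oW = (W + 2 * pW - kW) // sW + 1

    def at(c, d, h, w):
        # Coordinates outside the volume read the implicit zero padding.
        if 0 <= d < D and 0 <= h < H and 0 <= w < W:
            return x[c][d][h][w]
        return 0.0

    positions = list(product(range(oD), range(oH), range(oW)))
    cols = [
        [at(c, t * sT - pT + dt, i * sH - pH + dh, j * sW - pW + dw)
         for t, i, j in positions]
        for c, dt, dh, dw in product(range(C), range(kT), range(kH), range(kW))
    ]
    return cols, (oD, oH, oW)


def conv3d_mm(x, weight, bias=None, stride=(1, 1, 1), pad=(0, 0, 0)):
    """3-D convolution (cross-correlation, as in PyTorch) via vol2col + matmul.

    Hypothetical helper. x: C x D x H x W, weight: O x C x kT x kH x kW,
    bias: length-O list or None. Returns O x oD x oH x oW nested lists.
    """
    k = (len(weight[0][0]), len(weight[0][0][0]), len(weight[0][0][0][0]))
    cols, (oD, oH, oW) = vol2col(x, k, stride, pad)
    n_pos = oD * oH * oW
    out = []
    for o, w_o in enumerate(weight):
        # Flatten the filter in the same (c, dt, dh, dw) order as the rows of cols,
        # so each output value is the dot product of w_flat with one column.
        w_flat = [v for chan in w_o for plane in chan for row in plane for v in row]
        b = bias[o] if bias is not None else 0.0
        flat = [b + sum(wv * col[p] for wv, col in zip(w_flat, cols))
                for p in range(n_pos)]
        # Reshape the flat row of n_pos values back to oD x oH x oW.
        out.append([[flat[(t * oH + i) * oW:(t * oH + i + 1) * oW]
                     for i in range(oH)] for t in range(oD)])
    return out
```

For example, conv3d_mm([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]], [[[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]]], bias=[0.5]) sums all eight input values with an all-ones 2x2x2 filter and returns [[[[36.5]]]]. Materializing the full column matrix is what turns the convolution into a single matrix multiply; the price is extra memory proportional to C x kT x kH x kW times the number of output positions.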