def forward(self, input, weight, bias=None):
    """Compute the convolution output through the type-specific backend.

    The backend is looked up in ``type2backend`` by the type of ``input``
    and stored on ``self._backend``. Two scratch tensors are allocated with
    ``input.new()`` and kept on ``self.buffer1`` / ``self.buffer2``, and
    ``self.with_bias`` records whether a bias was supplied.

    Args:
        input: Input tensor; its type selects the backend and the routine.
        weight: Weight tensor forwarded to the backend routine.
        bias: Optional bias tensor; ``None`` is forwarded to the backend
            as-is when no bias is given.

    Returns:
        The tensor created with ``input.new()`` that was handed to the
        backend routine as its output argument.
    """
    backend = type2backend[type(input)]
    self._backend = backend

    # TODO: free buffers when not needed
    self.buffer1 = input.new()
    self.buffer2 = input.new()
    output = input.new()
    self.with_bias = bias is not None

    on_cuda_float = torch.typename(input) == 'torch.cuda.FloatTensor'
    if not on_cuda_float:
        # Every other tensor type goes through the MM routine, which takes
        # a single scratch buffer and the full additional_args list.
        backend.VolumetricConvolutionMM_updateOutput(
            backend.library_state, input, output, weight,
            bias, self.buffer1, *self.additional_args)
    else:
        # The CUDA float routine takes both scratch buffers and skips the
        # first three additional_args entries (presumably the kernel
        # sizes -- TODO confirm against the constructor).
        backend.VolumetricConvolution_updateOutput(
            backend.library_state, input, output, weight, bias,
            self.buffer1, self.buffer2, *self.additional_args[3:])

    # Only save the bias for backward when one was actually provided.
    if self.with_bias:
        to_save = (input, weight, bias)
    else:
        to_save = (input, weight)
    self.save_for_backward(*to_save)
    return output
评论列表
文章目录