def forward(self, input, weight, bias=None):
    """Run the convolution forward pass and return a freshly allocated output.

    The output tensor is created with ``input.new`` using the size reported
    by ``self._output_size(input, weight)``. ``input`` and ``weight`` (plus
    ``bias`` when one is supplied) are registered via ``save_for_backward``.
    The computation is dispatched to cuDNN when ``cudnn.is_acceptable(input)``
    is true; otherwise it falls back to THNN's ``SpatialConvolutionMM``.

    Args:
        input: input tensor.
        weight: convolution weight. On the THNN path, ``weight.size(3)`` and
            ``weight.size(2)`` are passed as the kernel extents (presumably a
            4-D ``(out, in, kH, kW)`` layout -- TODO confirm).
        bias: optional bias tensor, or ``None``.

    Returns:
        The output tensor, filled in by whichever backend was selected.

    Raises:
        ValueError: on the THNN path when ``self.groups != 1``.
    """
    out = input.new(*self._output_size(input, weight))

    # Only register bias with autograd when one was actually given.
    tensors_to_save = (input, weight) if bias is None else (input, weight, bias)
    self.save_for_backward(*tensors_to_save)

    if cudnn.is_acceptable(input):
        # cuDNN path: the value returned by the call is kept on the instance
        # (presumably consumed later by backward -- verify).
        self._cudnn_info = torch._C._cudnn_convolution_forward(
            input, weight, bias, out,
            self.pad[0], self.pad[1],
            self.stride[0], self.stride[1],
            self.groups, cudnn.benchmark)
        return out

    # THNN fallback.
    # TODO: implement groups for THNN
    if self.groups != 1:
        raise ValueError('THNN does not support groups')

    thnn = type2backend[type(input)]
    # Empty scratch tensors stored on self; THNN receives them as buffers
    # (presumably reused by backward -- verify).
    self._finput, self._fgrad_input = input.new(), input.new()
    # Note the [1]-before-[0] index order for kernel/stride/pad here, the
    # reverse of the cuDNN call above (THNN appears to take width first).
    thnn.SpatialConvolutionMM_updateOutput(
        thnn.library_state, input, out, weight, bias,
        self._finput, self._fgrad_input,
        weight.size(3), weight.size(2),
        self.stride[1], self.stride[0],
        self.pad[1], self.pad[0])
    return out
评论列表
文章目录