def _update_output(self, input, weight, bias):
    """Compute the convolution forward pass, using either cuDNN or THNN.

    Side effects on ``self``:
      * ``use_cudnn`` records which backend was chosen.
      * ``_cudnn_info`` (cuDNN path) or ``_bufs`` (THNN path) is kept only
        when ``self.requires_grad`` is true. It is presumably consumed by
        the backward pass; confirm against the rest of the class.

    Args:
        input: tensor to convolve. ``cudnn.is_acceptable`` decides whether
            cuDNN can handle it.
        weight: convolution weight tensor.
        bias: bias tensor, forwarded unchanged to the backend.

    Returns:
        The output tensor produced by the selected backend.
    """
    # Backend choice. With cuDNN older than 6000, a dilated convolution falls
    # back to THNN (presumably because older cuDNN lacks dilation support).
    self.use_cudnn = cudnn.is_acceptable(input)
    if self.use_cudnn and cudnn.version() < 6000:
        self.use_cudnn = not self.is_dilated()

    if not self.use_cudnn:
        # THNN path: one scratch list per convolution group.
        self._bufs = [[] for _ in range(self.groups)]
        result = self._thnn('update_output', input, weight, bias)
        if not self.requires_grad:
            del self._bufs
        return result

    # cuDNN path: preallocate the output, then run the matching kernel.
    result = input.new(*self._output_size(input, weight))
    cudnn_forward = (
        torch._C._cudnn_convolution_transpose_full_forward
        if self.transposed
        else torch._C._cudnn_convolution_full_forward
    )
    self._cudnn_info = cudnn_forward(
        input, weight, bias, result,
        self.padding, self.stride, self.dilation,
        self.groups, cudnn.benchmark)
    if not self.requires_grad:
        del self._cudnn_info
    return result
评论列表
文章目录