def _grad_input(self, input, weight, grad_output):
    """Compute the gradient with respect to ``input`` for this convolution.

    When ``self.use_cudnn`` is falsy the whole computation is delegated to
    ``self._thnn('grad_input', input, weight, grad_output)`` and that result
    is returned unchanged. Otherwise a buffer with the same size as ``input``
    is allocated and handed to a cuDNN kernel from ``torch._C``. The kernel
    also receives ``grad_output``, ``weight``, ``self._cudnn_info`` and the
    current ``cudnn.benchmark`` flag, and the buffer is then returned.

    Args:
        input: the forward-pass input. On the cuDNN path it only serves as
            the template for the output buffer
            (``input.new().resize_as_(input)``).
        weight: the convolution weight, forwarded to the kernel / backend.
        grad_output: the incoming gradient, forwarded to the kernel / backend.

    Returns:
        The freshly allocated ``grad_input`` buffer on the cuDNN path,
        otherwise whatever ``self._thnn`` returns.
    """
    # Non-cuDNN path: the THNN backend does all of the work.
    if not self.use_cudnn:
        return self._thnn('grad_input', input, weight, grad_output)

    # Output buffer with the same type and size as ``input``. It is not
    # initialised here; the kernel call below is expected to fill it.
    buf = input.new().resize_as_(input)
    if not self.transposed:
        # Regular convolution: the input gradient comes from the
        # backward-data kernel (buffer is the 2nd argument).
        torch._C._cudnn_convolution_backward_data(
            grad_output, buf, weight, self._cudnn_info, cudnn.benchmark)
    else:
        # ConvTranspose reuses the regular convolution kernels with the
        # forward/backward roles swapped. Its input gradient therefore comes
        # from the *forward* kernel (buffer is the 3rd argument).
        torch._C._cudnn_convolution_forward(
            grad_output, weight, buf, self._cudnn_info, cudnn.benchmark)
    return buf
评论列表
文章目录