def backward(self, grad_output):
    """Compute convolution gradients for the tensors saved in forward.

    Uses the cuDNN backward kernels when ``cudnn.is_acceptable`` accepts the
    saved input; otherwise falls back to the THNN ``SpatialConvolutionMM``
    backend selected via ``type2backend``. Only the gradients flagged in
    ``self.needs_input_grad`` are allocated and filled; the others stay
    ``None``.

    Args:
        grad_output: gradient of the loss with respect to the forward output.

    Returns:
        ``(grad_input, grad_weight, grad_bias)`` when a bias tensor was
        saved, otherwise ``(grad_input, grad_weight)``.
    """
    saved = self.saved_tensors
    # Forward saves only (input, weight) when there is no bias; pad with
    # None so a single three-way unpack covers both cases.
    if len(saved) == 2:
        saved = tuple(saved) + (None,)
    inp, weight, bias = saved

    # needs_input_grad[2] is only consulted when a bias actually exists.
    wants_bias = bias is not None and self.needs_input_grad[2]
    grad_input = grad_weight = grad_bias = None

    if cudnn.is_acceptable(inp):
        info = self._cudnn_info
        if self.needs_input_grad[0]:
            grad_input = inp.new().resize_as_(inp)
            torch._C._cudnn_convolution_backward_data(
                grad_output, grad_input, weight, info, cudnn.benchmark)
        if self.needs_input_grad[1]:
            grad_weight = weight.new().resize_as_(weight)
            torch._C._cudnn_convolution_backward_filter(
                grad_output, inp, grad_weight, info, cudnn.benchmark)
        if wants_bias:
            grad_bias = bias.new().resize_as_(bias)
            torch._C._cudnn_convolution_backward_bias(
                grad_output, grad_bias, info)
    else:
        backend = type2backend[type(inp)]

        def geometry():
            # Kernel size, stride and padding, each passed last-dim first
            # (presumably width before height, as THNN expects — confirm).
            return (weight.size(3), weight.size(2),
                    self.stride[1], self.stride[0],
                    self.pad[1], self.pad[0])

        if self.needs_input_grad[0]:
            grad_input = inp.new().resize_as_(inp).zero_()
            backend.SpatialConvolutionMM_updateGradInput(
                backend.library_state, inp, grad_output, grad_input,
                weight, self._finput, self._fgrad_input, *geometry())
        if any(self.needs_input_grad[1:]):
            # accGradParameters fills weight (and optionally bias) grads in a
            # single call, so a weight buffer is needed even if only the bias
            # gradient was requested.
            grad_weight = weight.new().resize_as_(weight).zero_()
            grad_bias = bias.new().resize_as_(bias).zero_() if wants_bias else None
            # Trailing 1 is the accumulation scale.
            backend.SpatialConvolutionMM_accGradParameters(
                backend.library_state, inp, grad_output, grad_weight,
                grad_bias, self._finput, self._fgrad_input,
                *(geometry() + (1,)))

    grads = (grad_input, grad_weight, grad_bias)
    return grads if bias is not None else grads[:2]
评论列表
文章目录