def _compute_grad_weight(self, grad_output):
    """Return ``(grad_weight, grad_bias)`` for the saved volumetric convolution.

    Two zero-filled tensors, created with the same type and shape as the saved
    ``weight`` and ``bias``, are passed to a backend ``*_accGradParameters``
    routine together with the saved input and ``grad_output``. That routine is
    opaque from here; presumably it accumulates the parameter gradients into
    those tensors in place (confirm against the backend).

    The ``VolumetricConvolution`` routine is chosen only when the saved input's
    type name is exactly ``'torch.cuda.FloatTensor'``. Every other tensor type
    goes through ``VolumetricConvolutionMM``.

    Args:
        grad_output: output gradient, forwarded unchanged to the backend.

    Returns:
        tuple: ``(grad_weight, grad_bias)``, the tensors handed to the backend.

    Raises:
        AttributeError: if the saved ``bias`` is ``None`` (``bias.new()`` is
            called unconditionally).
    """
    saved_input, weight, bias = self._get_saved_tensors()
    backend = self._backend

    # TODO: the explicit zero-fill should become unnecessary in the future.
    grad_weight = weight.new().resize_as_(weight).zero_()
    grad_bias = bias.new().resize_as_(bias).zero_()

    # Leading positional arguments shared by both backend routines.
    shared = (backend.library_state, saved_input, grad_output,
              grad_weight, grad_bias, self.buffer1)

    if torch.typename(saved_input) == 'torch.cuda.FloatTensor':
        # CUDA routine: second scratch buffer, additional_args without its
        # first three entries, then a trailing 1 (presumably the gradient
        # scale -- confirm against the backend signature).
        routine = backend.VolumetricConvolution_accGradParameters
        call_args = shared + (self.buffer2,) + self.additional_args[3:] + (1,)
    else:
        # MM routine: no second scratch buffer, just the same trailing 1.
        routine = backend.VolumetricConvolutionMM_accGradParameters
        call_args = shared + (1,)

    routine(*call_args)
    return grad_weight, grad_bias
评论列表
文章目录