def backward_extended(self, grad_output, grad_hy):
    input, hx, weight, output = self.saved_tensors
    input = input.contiguous()
    grad_input, grad_weight, grad_hx = None, None, None

    # This path is only reached when cuDNN handled the forward pass.
    assert cudnn.is_acceptable(input)

    # Allocate empty buffers that cudnn.rnn.backward_grad fills in place.
    grad_input = input.new()
    if torch.is_tensor(hx):
        grad_hx = input.new()                  # RNN/GRU: single hidden state
    else:
        grad_hx = tuple(h.new() for h in hx)   # LSTM: (h, c) pair

    # backward_grad consumes the reserve workspace saved during forward;
    # keep a copy if backward may be run through this graph again.
    if self.retain_variables:
        self._reserve_clone = self.reserve.clone()

    cudnn.rnn.backward_grad(
        self,
        input,
        hx,
        weight,
        output,
        grad_output,
        grad_hy,
        grad_input,
        grad_hx)

    # Weight gradients are computed only if some parameter actually needs them.
    if any(self.needs_input_grad[1:]):
        grad_weight = [tuple(w.new().resize_as_(w) for w in layer_weight)
                       for layer_weight in weight]
        cudnn.rnn.backward_weight(
            self,
            input,
            hx,
            output,
            weight,
            grad_weight)
    else:
        grad_weight = [(None,) * len(layer_weight) for layer_weight in weight]

    # Restore the workspace so a later backward pass sees the saved state.
    if self.retain_variables:
        self.reserve = self._reserve_clone
        del self._reserve_clone

    return grad_input, grad_weight, grad_hx
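The listing above comes from an old PyTorch internal class, so it cannot be run on its own. As a minimal sketch of the same backward pattern, the hypothetical MyLinear op below (not part of the code above) uses the public torch.autograd.Function API: the gradient with respect to the input is always produced, while the gradient with respect to the weight is computed only when needs_input_grad says it is required, mirroring the cudnn.rnn.backward_grad / backward_weight split.

import torch

class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        # Save the tensors needed for the backward pass, like self.saved_tensors above.
        ctx.save_for_backward(input, weight)
        return input.matmul(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        # Gradient w.r.t. the input, the analogue of cudnn.rnn.backward_grad.
        grad_input = grad_output.matmul(weight)
        # Gradient w.r.t. the weight only when it is actually needed,
        # the analogue of the needs_input_grad check before backward_weight.
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input)
        return grad_input, grad_weight

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
MyLinear.apply(x, w).sum().backward()

Skipping the weight-gradient pass when no parameter requires a gradient is the same cost-saving choice the cuDNN code makes: the input gradient is needed to keep backpropagation flowing, but the per-layer weight buffers are only allocated and filled on demand.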