# Method of a torch.autograd.Function subclass; assumes the enclosing module
# imports `torch` and `torch.backends.cudnn as cudnn`.
def backward_extended(self, grad_output, grad_hy):
    input, hx, weight, output = self.saved_tensors

    grad_input, grad_weight, grad_hx = None, None, None

    # This path is cuDNN-only; fail fast if cuDNN cannot handle the input.
    assert cudnn.is_acceptable(input)

    grad_input = input.new()
    # hx is a single tensor for RNN/GRU and an (h, c) tuple for LSTM, so the
    # gradient buffer must match its structure. (grad_weight stays None here;
    # it is only allocated below if the weights actually need a gradient.)
    if torch.is_tensor(hx):
        grad_hx = input.new()
    else:
        grad_hx = tuple(h.new() for h in hx)

    # cuDNN fills grad_input and grad_hx in place.
    cudnn.rnn.backward_grad(
        self,
        input,
        hx,
        weight,
        output,
        grad_output,
        grad_hy,
        grad_input,
        grad_hx)

    # Weight gradients are computed by a separate cuDNN call, and only when
    # the weights (input index 1) were marked as requiring grad.
    if self.needs_input_grad[1]:
        grad_weight = [tuple(w.new().resize_as_(w).zero_() for w in layer_weight)
                       for layer_weight in weight]
        cudnn.rnn.backward_weight(
            self,
            input,
            hx,
            output,
            weight,
            grad_weight)

    return grad_input, grad_weight, grad_hx
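
The gating on self.needs_input_grad above is the general autograd pattern: a backward only materializes gradients for inputs that actually require them, returning None for the rest. Below is a minimal, hypothetical sketch of that same pattern written against the modern torch.autograd.Function static-method API; the class MyLinear and its forward are made up for illustration, and only torch itself is assumed.

import torch

class MyLinear(torch.autograd.Function):
    # Hypothetical example, not part of the PyTorch source above.
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        # Compute each gradient only if the corresponding forward input
        # requires it, mirroring the needs_input_grad[1] check above.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t() @ input
        return grad_input, grad_weight

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
MyLinear.apply(x, w).sum().backward()  # populates x.grad and w.grad

Returning None for a gradient that is not needed lets autograd skip that computation entirely, which is exactly why the backward_weight call above is wrapped in the needs_input_grad[1] check.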