def forward_extended(self, input, weight, hx):
    """Run the cuDNN RNN forward pass and record tensors for backward.

    Args:
        input: Input tensor; must satisfy ``cudnn.is_acceptable``.
        weight: RNN weights, forwarded unchanged to ``cudnn.rnn.forward``.
        hx: Initial hidden state. This is either a single tensor or an
            iterable of tensors (presumably ``(h, c)`` for LSTM; confirm
            with the caller).

    Returns:
        A pair ``(output, hy)``. ``hy`` has the same structure as ``hx``:
        a tensor if ``hx`` is a tensor, otherwise a tuple of tensors.
    """
    # Internal invariant rather than user validation: like any assert,
    # this check is skipped entirely under ``python -O``.
    assert cudnn.is_acceptable(input)

    # Empty buffers of matching tensor type. They are handed to the cuDNN
    # call below and then returned, so that call presumably fills them in
    # place (NOTE(review): confirm against cudnn.rnn.forward).
    out_buf = input.new()
    if not torch.is_tensor(hx):
        hidden_out = tuple([state.new() for state in hx])
    else:
        hidden_out = hx.new()

    cudnn.rnn.forward(self, input, hx, weight, out_buf, hidden_out)

    # hx may be a tuple here. This assumes the enclosing Function (not
    # visible) accepts nested structures in save_for_backward; TODO confirm.
    self.save_for_backward(input, hx, weight, out_buf)
    return out_buf, hidden_out
评论列表
文章目录