def make_node(self, desc, x, y, dy, dhy, dcy, w, hx, cx, reserve):
    """Build the Apply node for this op.

    The arguments are not validated here: callers are trusted to pass
    compatible variables.

    Node inputs, in order: ``desc``, ``x.shape[2]`` converted with
    ``as_scalar`` and cast to ``'uint64'``, ``y``, ``dy``, ``w``, ``hx``,
    ``reserve``; followed by ``cx`` when ``self.rnn_mode == 'lstm'``,
    ``dhy`` when ``self.grad_h`` is truthy, and ``dcy`` when
    ``self.grad_c`` is truthy.

    Node outputs: fresh variables created from the types of ``reserve``,
    ``x`` and ``hx``, plus one from the type of ``cx`` in 'lstm' mode.

    NOTE(review): the positional order of inputs/outputs presumably has to
    match what the op's implementation reads -- do not reorder without
    checking it.
    """
    x_dim2 = as_scalar(x.shape[2]).astype('uint64')
    node_inputs = [desc, x_dim2, y, dy, w, hx, reserve]
    node_outputs = [reserve.type(), x.type(), hx.type()]

    # Only 'lstm' mode consumes cx and produces a matching output.
    if self.rnn_mode == 'lstm':
        node_inputs.append(cx)
        node_outputs.append(cx.type())

    # Optional gradient inputs go last, dhy strictly before dcy.
    optional_grads = ((self.grad_h, dhy), (self.grad_c, dcy))
    node_inputs.extend(grad for wanted, grad in optional_grads if wanted)

    return Apply(self, node_inputs, node_outputs)
# We have special requirements so this is hooking into COp
评论列表
文章目录