import tensorflow as tf
from tensorflow.python.framework import ops


def _thin_stack_update_gradient(op, stack_grad, *rest):
    stack = op.inputs[2]
    batch_size = op.inputs[4].get_shape().as_list()[0]
    t = op.get_attr("timestep")

    # We usually slice off the head of the stack output in feedforward and
    # send it off to downstream computation. The Slice feedforward op will
    # generate a sparse gradient in the backward pass. Nix this sparsity
    # at the very start.
    if isinstance(stack_grad, ops.IndexedSlices):
        # Trick: re-use our stack structure to store new gradients.
        # Recover the original stack variable from the lookup/update chain.
        stack = _fetch_stack(stack)
        stack = tf.assign(stack, tf.zeros_like(stack))
        stack = tf.scatter_update(stack, stack_grad.indices, stack_grad.values)
        stack_grad = stack

    with tf.control_dependencies([stack_grad]):
        # The gradient w.r.t. the pushed input is the block of rows in the
        # stack gradient belonging to timestep t's batch.
        input_grad = tf.slice(stack_grad, [t * batch_size, 0], [batch_size, -1])

    # Return one gradient per op input; None marks inputs (indices, shape
    # parameters, etc.) that receive no gradient.
    return input_grad, None, stack_grad, None, None, None
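# A gradient function like this only takes effect once it is registered for
# the corresponding forward op. A minimal TF 1.x registration sketch follows;
# the op name "ThinStackUpdate" is an assumption for illustration, not taken
# from the code above. Substitute the name the custom forward op was actually
# registered under, and TensorFlow will call _thin_stack_update_gradient to
# build the backward graph for every instance of that op.
ops.RegisterGradient("ThinStackUpdate")(_thin_stack_update_gradient)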