import uuid

import tensorflow as tf
from tensorflow.python.ops import gen_state_ops, math_ops


def _thin_stack_lookup_gradient(op, grad_stack1, grad_stack2, grad_buf_top, _):
    stack, buffer, _, _, buffer_cursors, transitions = op.inputs
    stack2_ptrs = op.outputs[3]
    t = op.get_attr("timestep")

    batch_size = buffer_cursors.get_shape().as_list()[0]
    num_tokens = buffer.get_shape().as_list()[0] // batch_size
    batch_range = math_ops.range(batch_size)
    batch_range_i = tf.to_float(batch_range)

    # Accumulate the scattered gradient contributions in temporary variables,
    # which are read out and destroyed at the end of this function.
    grad_stack_name = "grad_stack_%i_%s" % (t, str(uuid.uuid4())[:15])
    grad_buffer_name = "grad_buffer_%i_%s" % (t, str(uuid.uuid4())[:15])
    grad_stack = gen_state_ops._temporary_variable(
        stack.get_shape().as_list(), tf.float32, grad_stack_name)
    grad_buffer = gen_state_ops._temporary_variable(
        buffer.get_shape().as_list(), tf.float32, grad_buffer_name)
    grad_stack = tf.assign(grad_stack, tf.zeros_like(grad_stack))
    grad_buffer = tf.assign(grad_buffer, tf.zeros_like(grad_buffer))

    # Write grad_stack1 into block (t - 1). Rows are laid out so that row
    # (timestep * batch_size + batch_index) holds that example's stack entry.
    if t >= 1:
        in_cursors = (t - 1) * batch_size + batch_range
        grad_stack = tf.scatter_add(grad_stack, in_cursors, grad_stack1)

    # Write grad_stack2 using the stored lookup pointers.
    grad_stack = floaty_scatter_add(
        grad_stack, stack2_ptrs * batch_size + batch_range_i, grad_stack2)

    # Use buffer_cursors to scatter grads into the buffer, clipping the
    # pointers to the last valid row.
    buffer_ptrs = tf.minimum(float(num_tokens * batch_size) - 1.0,
                             buffer_cursors * batch_size + batch_range_i)
    grad_buffer = floaty_scatter_add(grad_buffer, buffer_ptrs, grad_buf_top)

    # Read out and destroy the temporaries once all scatter ops have run.
    with tf.control_dependencies([grad_stack, grad_buffer]):
        grad_stack = gen_state_ops._destroy_temporary_variable(grad_stack, grad_stack_name)
        grad_buffer = gen_state_ops._destroy_temporary_variable(grad_buffer, grad_buffer_name)

        with tf.control_dependencies([grad_stack, grad_buffer]):
            return grad_stack, grad_buffer, None, None, None, None
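
# floaty_scatter_add is a project-specific helper that is not shown in this
# snippet. A minimal sketch of what it presumably does, assuming it simply
# casts the float-valued row pointers to integers before scattering:
def floaty_scatter_add(ref, float_indices, updates):
    # Hypothetical implementation: cast indices to int32, then delegate to
    # tf.scatter_add on the (temporary) variable ref.
    indices = tf.to_int32(float_indices)
    return tf.scatter_add(ref, indices, updates)
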
# Deprecated custom gradient op.
#@ops.RegisterGradient("ThinStackLookup")
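# If the ThinStackLookup op were still available (e.g. built as a custom
# kernel), the gradient above would be attached via the decorator shown
# commented out, or equivalently with an imperative call. A sketch, assuming
# the op's registered name is "ThinStackLookup":
#
#     tf.RegisterGradient("ThinStackLookup")(_thin_stack_lookup_gradient)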