def shift_thin_stack(thin_stack, thin_stack_head_next, batch_size,
                     max_num_concepts, decoder_position,
                     prev_transition_state):
    """Write the decoder position onto the thin stack, then maybe push it.

    The slot addressed by ``thin_stack_head_next`` is always overwritten via
    ``write_thin_stack``. The head index is then advanced by one wherever
    ``prev_transition_state`` equals ``data_utils.GEN_STATE``, and left
    unchanged otherwise.

    Args:
      thin_stack: stack tensor, forwarded to ``write_thin_stack``.
      thin_stack_head_next: index of the slot just past the stack top
        (presumably one int32 index per batch entry -- TODO confirm).
      batch_size: forwarded to ``write_thin_stack``.
      max_num_concepts: forwarded to ``write_thin_stack``.
      decoder_position: value written into the addressed stack slot.
      prev_transition_state: transition-state ids, used to index a lookup
        table of length ``data_utils.NUM_TR_STATES``.

    Returns:
      A ``(new_thin_stack, new_thin_stack_head_next)`` pair.
    """
    # The head points one past the stack top, so its slot can always be
    # written. Whether that slot becomes part of the stack is decided below.
    updated_stack = write_thin_stack(thin_stack, thin_stack_head_next,
                                     decoder_position, batch_size,
                                     max_num_concepts)

    # Build a 0/1 table over transition states with a 1 only at GEN_STATE.
    # Gathering it with the previous state gives the amount to push.
    # NOTE(review): an earlier comment said "shift (or pointer shift)", but
    # only GEN_STATE is marked here -- confirm this is intended.
    shift_state_index = tf.pack([data_utils.GEN_STATE])
    table_shape = tf.pack([data_utils.NUM_TR_STATES])
    push_table = tf.sparse_to_dense(shift_state_index, table_shape, 1)

    push_amount = tf.gather(push_table, prev_transition_state)
    updated_head_next = tf.add(thin_stack_head_next, push_amount)
    return updated_stack, updated_head_next
评论列表
文章目录