def mask_decoder_only_shift(logit, thin_stack_head_next, transition_state_map,
                            logit_size, batch_size):
    """Force a shift (GEN_STATE) transition when the stack is (nearly) empty.

    For each batch entry k with thin_stack_head_next[k] <= 1, every logit
    whose transition state (looked up through transition_state_map) is
    data_utils.RE_STATE or data_utils.ARC_STATE is set to -inf.
    All other logits, and all entries with thin_stack_head_next[k] > 1,
    are left unchanged.

    Args:
      logit: Float tensor of logits, presumably shaped
        [batch_size, logit_size] (TODO confirm). The mask is added to it
        with broadcasting, and the -inf table is built in logit's dtype.
      thin_stack_head_next: Integer tensor of per-entry stack head
        pointers, presumably shaped [batch_size].
      transition_state_map: 1-D integer tensor mapping a logit index to its
        transition state id, in [0, data_utils.NUM_TR_STATES).
      logit_size: Number of logit columns to consider.
      batch_size: Batch size. It is no longer needed by the computation,
        because the mask shape is taken from thin_stack_head_next. It is
        kept for backward compatibility with existing callers.

    Returns:
      A tensor like `logit`, with the disallowed entries replaced by -inf.
    """
    logit = tf.convert_to_tensor(logit)

    # 1 where the stack holds at most one element, else 0. Shape [batch, 1].
    stack_is_empty = tf.reshape(
        tf.cast(tf.less_equal(thin_stack_head_next, 1), tf.int32), [-1, 1])

    # Lookup table over transition states: 1 for RE_STATE and ARC_STATE,
    # 0 for every other state.
    # NOTE(review): sparse_to_dense validates by default that its indices are
    # sorted and unique, so this relies on RE_STATE < ARC_STATE. Confirm.
    state_is_disallowed_updates = tf.sparse_to_dense(
        tf.pack([data_utils.RE_STATE, data_utils.ARC_STATE]),
        tf.pack([data_utils.NUM_TR_STATES]), 1)
    logit_states = tf.gather(transition_state_map, tf.range(logit_size))
    state_is_disallowed = tf.reshape(
        tf.gather(state_is_disallowed_updates, logit_states), [1, -1])

    # Outer product by broadcasting: [batch, 1] * [1, logit_size].
    # This avoids an int32 MatMul, which may lack a GPU kernel.
    # A value of 1 marks a disallowed (entry, logit) pair.
    index_delta = stack_is_empty * state_is_disallowed

    # Map 0 -> 0 and 1 -> -inf by table lookup instead of multiplying by
    # -inf, because 0 * -inf would give NaN for the allowed entries.
    values = tf.constant([0.0, -np.inf], dtype=logit.dtype)
    delta = tf.gather(values, index_delta)
    return logit + delta
评论列表
文章目录