def mask_decoder_reduce(logit, thin_stack_head_next, logit_size, batch_size):
    """Mask out the REDUCE action for batch entries whose stack pointer is < 2.

    For each batch entry k:
      * if thin_stack_head_next[k] is 0 or 1, logit[k][data_utils.REDUCE_ID]
        becomes -inf (presumably so a later softmax/argmax can never pick
        REDUCE for that entry);
      * if thin_stack_head_next[k] >= 2, the entry is left unchanged.
    No other column of ``logit`` is modified.

    Args:
      logit: logits tensor. It is added to a dense tensor of shape
        [batch_size, logit_size]. NOTE(review): presumably float32, because
        the mask values below are built from Python floats. Confirm with
        callers.
      thin_stack_head_next: per-entry stack head pointers. They are used as
        gather indices after being clamped to at most 2. Presumably a
        non-negative int32 vector of length batch_size; a negative value
        would be an out-of-range gather index. Confirm with callers.
      logit_size: number of columns of ``logit``.
      batch_size: number of rows of ``logit``.

    Returns:
      A new tensor equal to ``logit`` with the REDUCE column masked as above.
      ``logit`` itself is not modified.
    """
    batch_shape = tf.pack([batch_size])

    # Allow reduce only when thin_stack_head_next >= 2. (Author's note: this
    # corresponds to at least 1 item on the stack.) Clamping the pointer at 2
    # turns it into an index into [-inf, -inf, 0.0]: pointers 0 and 1 pick
    # -inf, and anything >= 2 picks 0.0 (no change).
    update_vals = tf.pack([-np.inf, -np.inf, 0.0])
    clamped_head = tf.minimum(thin_stack_head_next, tf.fill(batch_shape, 2))
    update_val = tf.gather(update_vals, clamped_head)

    # Build one (row, REDUCE_ID) index pair per batch entry. tf.pack yields
    # shape [2, batch_size] and the transpose turns it into [batch_size, 2].
    # Rows come from tf.range, so the indices are already in row-major order.
    # SparseTensor indices and dense shape must be int64.
    reduce_col = tf.fill(batch_shape, tf.to_int64(data_utils.REDUCE_ID))
    re_inds = tf.transpose(tf.pack(
        [tf.to_int64(tf.range(batch_size)), reduce_col]))
    re_delta = tf.SparseTensor(re_inds, update_val, tf.to_int64(
        tf.pack([batch_size, logit_size])))

    # The densified delta is 0 everywhere outside the REDUCE column, so
    # adding it leaves every other logit untouched.
    new_logit = logit + tf.sparse_tensor_to_dense(re_delta)
    return new_logit
评论列表
文章目录