def mask_probs(probs, eos_token, finished):
    """Masks log probabilities so finished beams put all mass on EOS.

    Unfinished beams keep their original log probabilities; finished beams
    are replaced by a row that has log-probability 0 (i.e. probability 1)
    on the EOS token and the most negative representable float everywhere
    else, so no other token can be selected for a finished beam.

    Args:
      probs: Log probabilities of shape `[beam_width, vocab_size]`.
      eos_token: An int32 id of the EOS token that receives all the
        probability mass for finished beams.
      finished: A boolean tensor of shape `[beam_width]` marking which
        beams are already finished.

    Returns:
      A tensor of shape `[beam_width, vocab_size]` where unfinished beams
      are unchanged and finished beams carry all probability on EOS.
    """
    vocab_size = tf.shape(probs)[1]
    # 1.0 for unfinished beams, 0.0 for finished ones, broadcastable over
    # the vocab axis.  (The original code cast to float twice; the inner
    # tf.to_float already yields a float tensor, so one cast suffices.)
    finished_mask = tf.expand_dims(1. - tf.to_float(finished), 1)
    # Unfinished beams pass through unchanged.
    non_finished_examples = finished_mask * probs
    # Row used for every finished beam: log-prob 0 on EOS, effectively
    # -inf (float32 min) on all other tokens.
    finished_row = tf.one_hot(
        eos_token,
        vocab_size,
        dtype=tf.float32,
        on_value=0.,
        off_value=tf.float32.min)
    finished_examples = (1. - finished_mask) * finished_row
    return finished_examples + non_finished_examples