beam_search.py source code

python

Project: seq2seq  Author: google
import tensorflow as tf


def mask_probs(probs, eos_token, finished):
  """Masks log probabilities such that finished beams
  allocate all probability mass to eos. Unfinished beams remain unchanged.

  Args:
    probs: Log probabilities of shape `[beam_width, vocab_size]`
    eos_token: An int32 id corresponding to the EOS token to allocate
      probability to
    finished: A boolean tensor of shape `[beam_width]` that specifies which
      elements in the beam are finished already.

  Returns:
    A tensor of shape `[beam_width, vocab_size]`, where unfinished beams
    stay unchanged and finished beams are replaced with a tensor that has all
    probability on the EOS token.
  """
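  vocab_size = tf.shape(probs)[1]
  # 1.0 for unfinished beams, 0.0 for finished ones; shape [beam_width, 1]
  # so that it broadcasts across the vocabulary dimension.
  finished_mask = tf.expand_dims(1. - tf.cast(finished, tf.float32), 1)
  # Unfinished beams keep their original log probabilities.
  non_finished_examples = finished_mask * probs
  # Finished beams are replaced with a row that puts all probability on EOS:
  # log prob 0 (probability 1) at eos_token, float32.min everywhere else.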
  finished_row = tf.one_hot(
      eos_token,
      vocab_size,
      dtype=tf.float32,
      on_value=0.,
      off_value=tf.float32.min)
  finished_examples = (1. - finished_mask) * finished_row
  return finished_examples + non_finished_examples
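The masking trick can be checked without TensorFlow. The sketch below is a NumPy re-implementation written for illustration, not code from the seq2seq project; `mask_log_probs` and `beam_step_scores` are hypothetical names. It swaps the multiply-and-add used by `mask_probs` for `np.where`, which avoids a `0 * -inf = NaN` result if a finished row ever contains `-inf`. It also shows why the EOS entry is a log probability of 0: when a finished beam's masked row is added to its running score, the EOS candidate keeps that score unchanged while every other token drops to roughly `float32.min`, so in practice EOS is the finished beam's only competitive continuation.

```python
import numpy as np


def mask_log_probs(log_probs: np.ndarray, eos_token: int,
                   finished: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy sketch of the masking done by mask_probs.

    Uses np.where instead of multiply-and-add, so -inf values in
    finished rows are replaced outright rather than producing NaN.
    """
    vocab_size = log_probs.shape[1]
    # Row with all probability mass on EOS: log(1) = 0 at eos_token,
    # the most negative float32 value everywhere else.
    eos_row = np.full(vocab_size, np.finfo(np.float32).min, dtype=np.float32)
    eos_row[eos_token] = 0.0
    # finished[:, None] broadcasts the per-beam flag across the vocabulary.
    return np.where(finished[:, None], eos_row, log_probs).astype(np.float32)


def beam_step_scores(beam_scores: np.ndarray, log_probs: np.ndarray,
                     eos_token: int, finished: np.ndarray) -> np.ndarray:
    """Hypothetical helper: candidate scores for one beam-search step."""
    masked = mask_log_probs(log_probs, eos_token, finished)
    # Candidate score = running beam score + log prob of the next token.
    return beam_scores[:, None] + masked


# Two beams over a 3-token vocabulary (token 0 plays the role of EOS);
# beam 1 has already finished.
log_probs = np.log(np.array([[0.1, 0.6, 0.3],
                             [0.5, 0.2, 0.3]], dtype=np.float32))
finished = np.array([False, True])
beam_scores = np.array([-1.0, -2.0], dtype=np.float32)

masked = mask_log_probs(log_probs, eos_token=0, finished=finished)
scores = beam_step_scores(beam_scores, log_probs, eos_token=0, finished=finished)
print(scores)
```

Because the finished beam's EOS candidate scores exactly its previous score, a finished hypothesis is carried forward unchanged from step to step instead of being penalized or extended.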