beam_search.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: google 项目源码 文件源码
def length_penalty(sequence_lengths, penalty_factor):
  """Calculates the length penalty according to
  https://arxiv.org/abs/1609.08144

   Args:
    sequence_lengths: The sequence length of all hypotheses, a tensor
      of shape [beam_size, vocab_size].
    penalty_factor: A scalar that weights the length penalty.

  Returns:
    The length penalty factor, a tensor fo shape [beam_size].
   """
  return tf.div((5. + tf.to_float(sequence_lengths))**penalty_factor, (5. + 1.)
                **penalty_factor)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号