def length_penalty(sequence_lengths, penalty_factor):
    """Calculate the GNMT length penalty.

    Implements lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha from Wu et al.,
    "Google's Neural Machine Translation System"
    (https://arxiv.org/abs/1609.08144).

    Args:
        sequence_lengths: Lengths of the hypotheses, as a tensor (or a value
            convertible to one). It is cast to float32 before use. The
            penalty is computed element-wise, so any shape is accepted
            (e.g. [beam_size] or [beam_size, vocab_size]).
        penalty_factor: Scalar weight (alpha) of the length penalty. A value
            of 0 makes every element of the result 1.0.

    Returns:
        A float32 tensor with the same shape as `sequence_lengths` holding
        the length penalty of each hypothesis.
    """
    # tf.to_float / tf.div were deprecated in TF 1.x and removed in TF 2.x;
    # tf.cast and the `/` operator give the same result here in both versions.
    lengths = tf.cast(sequence_lengths, tf.float32)
    numerator = (5.0 + lengths) ** penalty_factor
    # Normalizer is the penalty of a length-1 hypothesis, so lp(1) == 1.
    denominator = (5.0 + 1.0) ** penalty_factor
    return numerator / denominator
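# (Web-page residue from the scraped source, not code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)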