import tensorflow as tf


def _masked_softmax(logits, lengths):
    """Softmax over the last axis, ignoring padded positions.

    Positions past each sequence's true length get zero probability;
    the remaining probabilities sum to one.
    """
    # Build a [batch, 1, max_len] mask: 1.0 at valid positions, 0.0 at padding.
    sequence_mask = tf.expand_dims(
        tf.sequence_mask(
            lengths, maxlen=tf.shape(logits)[-1], dtype=tf.float32),
        axis=1)
    # Subtract the per-row max before exponentiating, for numerical stability.
    max_logits = tf.reduce_max(logits, axis=-1, keepdims=True)
    masked_logit_exp = tf.exp(logits - max_logits) * sequence_mask
    # Normalize over valid positions only; padded positions stay exactly 0.
    logit_sum = tf.reduce_sum(masked_logit_exp, axis=-1, keepdims=True)
    probs = masked_logit_exp / logit_sum
    return probs
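
A minimal usage sketch, assuming TensorFlow 2.x eager execution; the logits and lengths values below are made up for illustration (in decomposable attention, the masked axis is typically the padded key positions of the attention scores):

# Hypothetical attention logits: batch of 2, 1 query position, 4 key positions.
logits = tf.constant([[[2.0, 1.0, 0.5, -1.0]],
                      [[0.3, 0.7, 1.2, 0.9]]])
lengths = tf.constant([3, 2])  # true key lengths per batch example

probs = _masked_softmax(logits, lengths)
# probs[0, 0, :3] sums to 1 and probs[0, 0, 3] == 0.0;
# probs[1, 0, :2] sums to 1 and probs[1, 0, 2:] == 0.0.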
Source file: decomposable_attention_ops.py (Python)