segment.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def segment_softmax(scores, segment_ids):
    """Turn `scores` into probabilities via a softmax taken per segment.

    Entries along the first dimension of `scores` that share a value in
    `segment_ids` belong to the same partition. When `scores` has static
    rank 2, each row is first reduced over axis 1, so all entries of all rows
    in a partition share one maximum and one normaliser. For any other rank
    the segment reductions are applied to `scores` directly.

    Args:
        scores: tensor of unnormalised scores with a statically known rank.
        segment_ids: tensor of partition indices, one per entry along the
            first dimension of `scores` (presumably non-negative integers;
            verify against callers).

    Returns:
        A tensor of probabilities with the same shape as `scores`.
    """
    is_matrix = len(scores.get_shape()) == 2
    segment_count = tf.reduce_max(segment_ids) + 1

    def _per_row(per_segment):
        # Map each per-segment value back onto its rows; in the rank-2 case
        # add a trailing axis so it broadcasts across the columns of `scores`.
        looked_up = tf.gather(per_segment, segment_ids)
        return tf.expand_dims(looked_up, axis=1) if is_matrix else looked_up

    # Shift by the partition maximum so the exponentials cannot overflow.
    row_peaks = tf.reduce_max(scores, axis=1) if is_matrix else scores
    segment_peaks = tf.unsorted_segment_max(row_peaks, segment_ids, segment_count)
    shifted = scores - _per_row(segment_peaks)

    # Exponentiate, then divide by the total exponential mass of the partition.
    exps = tf.exp(shifted)
    row_mass = tf.reduce_sum(exps, axis=1) if is_matrix else exps
    segment_mass = tf.unsorted_segment_sum(row_mass, segment_ids, segment_count)
    return exps / _per_row(segment_mass)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号