ops.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:document-qa 作者: allenai 项目源码 文件源码
def segment_logsumexp(xs, segments):
    """ Similar tf.segment_sum but compute logsumexp rather then sum """
    # Stop gradients following the implementation of tf.reduce_logsumexp
    maxs = tf.stop_gradient(tf.reduce_max(xs, axis=1))
    segment_maxes = tf.segment_max(maxs, segments)
    xs -= tf.expand_dims(tf.gather(segment_maxes, segments), 1)
    sums = tf.reduce_sum(tf.exp(xs), axis=1)
    return tf.log(tf.segment_sum(sums, segments)) + segment_maxes
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号