def segment_logsumexp(xs, segments):
    """Like ``tf.segment_sum``, but computes a log-sum-exp instead of a sum.

    For each segment id ``s``, computes ``log(sum(exp(xs[i, j, ...])))`` over
    every row ``i`` with ``segments[i] == s`` and every index ``j`` along
    axis 1. The segment maximum is subtracted before exponentiating so large
    inputs do not overflow.

    Args:
        xs: Floating-point tensor (or tensor-convertible value) of rank >= 2.
            Rows (axis 0) are grouped by ``segments``; axis 1 is summed over;
            any further axes are kept.
        segments: 1-D integer tensor with one segment id per row of ``xs``.
            Must be sorted, as required by ``tf.segment_max`` and
            ``tf.segment_sum``.

    Returns:
        Tensor of shape ``[max(segments) + 1] + xs.shape[2:]``. A segment whose
        entries are all ``-inf`` yields ``-inf``; one containing ``+inf``
        yields ``+inf``.
    """
    # Convert once and subtract without ``-=`` so the shift always builds a
    # new tensor rather than updating a caller-supplied array in place.
    xs = tf.convert_to_tensor(xs)
    # Stop gradients through the shift, following tf.reduce_logsumexp: the
    # shift cancels out mathematically, so no gradient needs to flow through it.
    row_maxes = tf.stop_gradient(tf.reduce_max(xs, axis=1))
    segment_maxes = tf.segment_max(row_maxes, segments)
    # Like tf.reduce_logsumexp, replace non-finite maxima with 0. Without this,
    # a segment of all -inf (or one containing +inf) computes inf - inf = NaN
    # below instead of returning -inf (or +inf).
    segment_maxes = tf.where(tf.is_finite(segment_maxes),
                             segment_maxes,
                             tf.zeros_like(segment_maxes))
    # Broadcast each row's segment maximum back across axis 1 and shift.
    shifted = xs - tf.expand_dims(tf.gather(segment_maxes, segments), 1)
    row_sums = tf.reduce_sum(tf.exp(shifted), axis=1)
    return tf.log(tf.segment_sum(row_sums, segments)) + segment_maxes
评论列表
文章目录