def segment_softmax(scores, segment_ids):
    """Convert scores to probabilities with a softmax taken within each segment.

    Rows of ``scores`` that share a segment id are normalised together.

    * If ``scores`` has rank 2, every entry of every row in a segment belongs to
      one softmax group: the probabilities of a segment sum to 1 across both its
      rows and its columns.
    * For any other rank, the segment reduction is applied directly to
      ``scores``. With a 1-D ``segment_ids`` this reduces over the leading
      dimension only, so each trailing position is normalised independently.

    The per-segment maximum is subtracted before exponentiating to avoid
    overflow in ``tf.exp``.

    Args:
        scores: Float tensor whose leading dimension is aligned with
            ``segment_ids``.
        segment_ids: Integer tensor with one segment id per row of ``scores``.
            For the rank-2 path it must have shape ``[rows]``. Assumes
            non-negative ids (TODO confirm with callers).

    Returns:
        A tensor with the same shape as ``scores`` holding the probabilities.
    """
    # With an empty ``segment_ids``, reduce_max yields the dtype's lowest value,
    # so ``+ 1`` would produce a negative segment count and the segment ops
    # would fail. Clamp at zero so empty input yields an empty result.
    num_segments = tf.maximum(tf.reduce_max(segment_ids) + 1, 0)
    is_matrix = len(scores.get_shape()) == 2

    def _per_row(per_segment):
        # Map one value per segment back onto each row of ``scores``. In the
        # rank-2 case, add a trailing axis so it broadcasts across the columns.
        gathered = tf.gather(per_segment, segment_ids)
        return tf.expand_dims(gathered, axis=1) if is_matrix else gathered

    # Subtract the segment maximum for numerical stability. In the rank-2 case
    # each row is first collapsed to its own max, so the segment max spans all
    # columns of all rows in the segment.
    row_max = tf.reduce_max(scores, axis=1) if is_matrix else scores
    segment_max = tf.unsorted_segment_max(row_max, segment_ids, num_segments)
    scores_exp = tf.exp(scores - _per_row(segment_max))

    # Normalise by the per-segment sum of exponentials. Each segment contains
    # its max element, which contributes exp(0) == 1, so the sum is >= 1 for
    # finite inputs.
    row_sum = tf.reduce_sum(scores_exp, axis=1) if is_matrix else scores_exp
    segment_sum = tf.unsorted_segment_sum(row_sum, segment_ids, num_segments)
    return scores_exp / _per_row(segment_sum)
评论列表
文章目录