def segment_indices(segment_ids, name=None):
    """Returns a `Tensor` of indices within each segment.

    `segment_ids` should be a sorted (non-decreasing) sequence of non-negative
    integers that defines a set of segments, e.g. `[0, 0, 1, 2, 2, 2]` defines
    3 segments of lengths 2, 1 and 3. For each element, the result holds its
    position within its own segment.

    Example input: [0, 0, 1, 2, 2, 2]
    Example output: [0, 1, 0, 0, 1, 2]

    Segment ids need not start at 0 or be contiguous: ids that never occur
    simply form empty segments, e.g. [1, 1, 3] -> [0, 1, 0].

    Args:
        segment_ids: A 1-D `Tensor` (or a value convertible to one, such as a
            Python list) containing a non-decreasing sequence of non-negative
            integers with type `tf.int32` or `tf.int64`.
        name: (Optional) A name for this operation.

    Returns:
        A 1-D `Tensor` with the same dtype and length as `segment_ids`,
        containing the indices within each segment.
    """
    with tf.name_scope(name, 'segment_indices'):
        # Accept Python lists / NumPy arrays as well as Tensors; `.dtype` is
        # read below, which only Tensors provide. No-op for Tensor inputs.
        segment_ids = tf.convert_to_tensor(segment_ids, name='segment_ids')
        # Number of elements carrying each id; ids that never occur get 0.
        segment_lengths = tf.segment_sum(tf.ones_like(segment_ids), segment_ids)
        # An exclusive cumulative sum yields the flat offset at which each
        # segment starts (0 for the first one); look it up per element.
        segment_starts = tf.gather(
            tf.cumsum(segment_lengths, exclusive=True), segment_ids)
        # Flat position of every element, built in the input's dtype so the
        # subtraction below never mixes int32 and int64.
        positions = tf.range(tf.size(segment_ids, out_type=segment_ids.dtype))
        return positions - segment_starts
评论列表
文章目录