def segment_argmax(input, segment_ids):
    """Computes row and col indices Tensors of the segment max in the 2D input.

    For every partition the function keeps exactly one row holding the
    partition's maximum and one maximal column within that row. It assumes
    ``segment_is_max`` flags the entries equal to their partition's maximum
    (TODO confirm). Under that assumption, ties across several rows of a
    partition resolve to the row with the largest index. Ties across columns
    resolve to whichever column ``tf.argmax`` returns, which TensorFlow does
    not guarantee to be the first.

    Args:
      input: 2D Tensor of values. The name shadows the builtin but is kept
        for backward compatibility with keyword callers.
      segment_ids: Tensor of partition ids. Presumably there is one id per row
        of ``input`` and every id in ``[0, max(segment_ids)]`` occurs (see
        Raises). TODO: confirm against ``segment_is_max``.

    Returns:
      A tuple ``(row_indices, col_indices)`` of 1D int64 Tensors with one
      entry per selected row. The entries are ordered by ascending row index,
      which matches partition order only if ``segment_ids`` is sorted.

    Raises:
      tf.errors.InvalidArgumentError: if the number of selected rows differs
        from ``max(segment_ids) + 1``. This happens, for example, when some
        id in that range has no rows.
    """
    with tf.name_scope("segment_argmax"):
        # Compare in int64: the selected-row count below is int64, while
        # segment_ids may be int32, and tf.assert_equal rejects mixed dtypes.
        num_partitions = tf.cast(tf.reduce_max(segment_ids), tf.int64) + 1
        is_max = segment_is_max(input, segment_ids)
        # is_max can still contain several True entries per partition.
        # Duplicates inside one row are harmless, because argmax below picks a
        # single column. Trues spread over several rows of the same partition
        # must be reduced to one row. Weighting each True by (row index + 1)
        # and running segment_is_max() again keeps only the highest-weighted,
        # hence unique, row.
        num_rows = tf.shape(input)[0]
        num_cols = tf.shape(input)[1]
        row_ids = tf.tile(tf.expand_dims(tf.range(num_rows), 1), [1, num_cols])
        is_max = segment_is_max(
            tf.cast(is_max, tf.int32) * (row_ids + 1), segment_ids
        )
        # Mark the rows that still contain at least one True.
        row_selected = tf.reduce_any(is_max, axis=1)
        num_selected = tf.reduce_sum(tf.cast(row_selected, tf.int64))
        # Exactly one row per partition must survive. Building the indices
        # under the control dependency forces the check in graph mode too.
        # tf.where yields an [N, 1] int64 tensor; the reshape makes it 1D
        # even when N == 1.
        with tf.control_dependencies(
            [tf.assert_equal(num_selected, num_partitions)]
        ):
            row_indices = tf.reshape(tf.where(row_selected), [-1])
        selected_rows_is_max = tf.gather(is_max, row_indices)
        col_indices = tf.argmax(tf.cast(selected_rows_is_max, tf.int64), axis=1)
        return row_indices, col_indices
评论列表
文章目录