segment.py source code

python

Project: jack  Author: uclmr
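import tensorflow as tf

# segment_is_max() is not shown on this page; it is assumed to be defined
# elsewhere in the same segment.py module.


def segment_argmax(input, segment_ids):
    """Finds the position of the maximum of each segment in a 2D input.

    Args:
        input: 2D tensor of scores.
        segment_ids: 1D integer tensor with one segment id per row of `input`.

    Returns:
        Tuple (row_indices, col_indices) of 1D int64 tensors with one entry
        per segment, locating that segment's maximum in `input`.
    """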

    with tf.name_scope("segment_argmax"):
        num_partitions = tf.reduce_max(segment_ids) + 1
        is_max = segment_is_max(input, segment_ids)

        # is_max can still hold several True entries per partition when the
        # maximum is tied. Ties within a single row are harmless, but ties
        # spread over several rows of the same partition must be reduced to
        # one row. To do that, multiply is_max by (row index + 1) and apply
        # segment_is_max() again: only the highest-indexed tied row of each
        # partition keeps its True entries.

        rows = tf.shape(input)[0]
        cols = tf.shape(input)[1]
        row_indices = tf.tile(tf.expand_dims(tf.range(rows), 1), [1, cols])
        is_max = segment_is_max(tf.cast(is_max, tf.int32) * (row_indices + 1), segment_ids)

        # Get selected rows and columns
        row_selected = tf.reduce_any(is_max, axis=1)
        row_indices = tf.squeeze(tf.where(row_selected))
        rows_selected = tf.reduce_sum(tf.cast(row_selected, tf.int64))

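        # Assert that exactly one row was selected per partition, and reshape
        # so row_indices stays 1D even when tf.squeeze returned a scalar
        # (only one row selected). num_partitions is cast to int64 because
        # rows_selected is int64 and tf.assert_equal needs matching dtypes.
        expected_rows = tf.cast(num_partitions, tf.int64)
        with tf.control_dependencies([tf.assert_equal(rows_selected, expected_rows)]):
            row_indices = tf.reshape(row_indices, [-1])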

        selected_rows_is_max = tf.gather(is_max, row_indices)
        col_indices = tf.argmax(tf.cast(selected_rows_is_max, tf.int64), axis=1)

        # Return one (row, column) index pair per partition
        return row_indices, col_indices
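The graph-mode steps are easier to check on small inputs outside TensorFlow. The following is a minimal NumPy sketch of the same procedure, not code from jack. Because `segment_is_max` is not shown on this page, `segment_is_max_np` is a hypothetical stand-in that assumes it marks every entry equal to the maximum over all rows sharing a segment id. The names `segment_is_max_np` and `segment_argmax_np` are illustrative only. The sketch also assumes every segment id from 0 to the maximum id appears at least once, which is what the assertion in the original enforces.

```python
import numpy as np


def segment_is_max_np(x, segment_ids):
    """Hypothetical stand-in for segment_is_max: marks entries of the 2D
    array x that equal the maximum over all rows sharing their segment id."""
    num_segments = segment_ids.max() + 1
    # Maximum of each row, then the maximum over the rows of each segment.
    row_max = x.max(axis=1)
    seg_max = np.full(num_segments, -np.inf)
    np.maximum.at(seg_max, segment_ids, row_max)
    # Broadcast each row's segment maximum across its columns and compare.
    return x == seg_max[segment_ids][:, None]


def segment_argmax_np(x, segment_ids):
    """NumPy sketch of segment_argmax: one (row, col) pair per segment."""
    num_partitions = segment_ids.max() + 1
    is_max = segment_is_max_np(x, segment_ids)

    # Break ties across rows: weight each True by (row index + 1) and take
    # the segment maximum again, so only the last tied row survives.
    rows = x.shape[0]
    weighted = is_max.astype(np.int64) * (np.arange(rows)[:, None] + 1)
    is_max = segment_is_max_np(weighted, segment_ids)

    # Exactly one row per segment should remain selected.
    row_indices = np.flatnonzero(is_max.any(axis=1))
    assert row_indices.size == num_partitions, "expected one row per segment"

    # Within the selected row, NumPy's argmax picks the first tied column.
    col_indices = is_max[row_indices].argmax(axis=1)
    return row_indices, col_indices


x = np.array([[1, 5, 2],
              [5, 0, 5],
              [3, 4, 4],
              [9, 9, 0]])
segment_ids = np.array([0, 0, 1, 1])
rows, cols = segment_argmax_np(x, segment_ids)
print(rows, cols)  # prints: [1 3] [0 0]
```

In this example segment 0 has its maximum 5 in both rows 0 and 1, and the row weighting keeps row 1; segment 1 has a tie inside row 3, which is left to the final argmax. One difference to keep in mind: NumPy's `argmax` returns the first tied column, whereas the TensorFlow documentation for `tf.math.argmax` does not guarantee which index is returned on ties.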