def sort_by_field(boxlist, field, order=SortOrder.descend, scope=None):
    """Reorder a BoxList according to the values of one of its fields.

    The typical use is arranging boxes by descending score.

    Args:
        boxlist: BoxList holding N boxes.
        field: the BoxList field whose values determine the new ordering.
        order: (Optional) SortOrder.descend or SortOrder.ascend. Defaults to
            SortOrder.descend.
        scope: name scope.

    Returns:
        sorted_boxlist: the BoxList obtained by gathering `boxlist` at the
            sorted indices, i.e. with the field in the requested order.

    Raises:
        ValueError: if specified field does not exist
        ValueError: if the order is not either descend or ascend
        ValueError: if the field's shape does not have exactly one dimension
    """
    with tf.name_scope(scope, 'SortByField'):
        # Reject anything that is neither of the two supported orders.
        supported_orders = (SortOrder.descend, SortOrder.ascend)
        if all(order != candidate for candidate in supported_orders):
            raise ValueError('Invalid sort order')

        sort_values = boxlist.get_field(field)
        field_dims = sort_values.shape.as_list()
        if len(field_dims) != 1:
            raise ValueError('Field should have rank 1')

        # The field must carry exactly one entry per box; this is checked
        # with an Assert op that the sorting op is made to depend on.
        num_boxes = boxlist.num_boxes()
        field_size = tf.size(sort_values)
        sizes_match = tf.equal(num_boxes, field_size)
        size_check = tf.Assert(
            sizes_match,
            ['Incorrect field size: actual vs expected.',
             field_size, num_boxes])

        # TODO: Remove with tf.device when top_k operation runs correctly
        # on GPU.
        with tf.control_dependencies([size_check]), tf.device('/cpu:0'):
            # Asking top_k for all num_boxes entries yields the full
            # descending ordering; only the indices are needed.
            _, permutation = tf.nn.top_k(sort_values, num_boxes, sorted=True)

        if order == SortOrder.ascend:
            # Ascending order is the descending permutation reversed.
            permutation = tf.reverse_v2(permutation, [0])

        return gather(boxlist, permutation)
评论列表
文章目录