def retain_boxes_above_threshold(
        boxes, labels, label_scores, masks=None, keypoints=None, threshold=0.0):
    """Retains boxes whose label score is above a given threshold.

    A box is kept when its label score is strictly greater than `threshold`.
    If the label score for a box is missing (represented by NaN), the box is
    also retained. Boxes that don't pass the threshold will not appear in the
    returned tensors.

    Args:
        boxes: float32 tensor of shape [num_instance, 4] representing boxes
            location in normalized coordinates.
        labels: rank 1 int32 tensor of shape [num_instance] containing the
            object classes.
        label_scores: float32 tensor of shape [num_instance] representing the
            score for each box.
        masks: (optional) rank 3 float32 tensor with shape
            [num_instances, height, width] containing instance masks.
        keypoints: (optional) rank 3 float32 tensor with shape
            [num_instances, num_keypoints, 2]. The keypoints are in y-x
            normalized coordinates.
        threshold: scalar python float.

    Returns:
        A list whose first three entries are always:
            retained_boxes: [num_retained_instance, 4]
            retained_labels: [num_retained_instance]
            retained_label_scores: [num_retained_instance]
        followed, in this order and only for arguments that are not None, by:
            retained_masks: [num_retained_instance, height, width]
            retained_keypoints: [num_retained_instance, num_keypoints, 2]
        Callers must therefore unpack positionally according to which optional
        arguments they passed (e.g. passing only `keypoints` yields a list of
        length 4 with the keypoints at index 3).
    """
    with tf.name_scope('RetainBoxesAboveThreshold',
                       values=[boxes, labels, label_scores]):
        # NaN compares False against any threshold, so missing scores have to
        # be matched explicitly for those boxes to be retained.
        keep = tf.logical_or(label_scores > threshold, tf.is_nan(label_scores))
        # tf.where on a rank-1 boolean tensor returns indices of shape
        # [num_retained, 1]; squeeze to a rank-1 index vector for tf.gather.
        indices = tf.squeeze(tf.where(keep), axis=1)

        result = [
            tf.gather(boxes, indices),
            tf.gather(labels, indices),
            tf.gather(label_scores, indices),
        ]
        # Optional per-instance tensors are filtered with the same indices so
        # they stay aligned with the retained boxes.
        if masks is not None:
            result.append(tf.gather(masks, indices))
        if keypoints is not None:
            result.append(tf.gather(keypoints, indices))
        return result
# (Scraped web-page residue, not part of the code: "Comment list", "Table of contents".)