preprocessor.py source code

python

Project: tensorflow  Author: luyishisi
import tensorflow as tf  # Written against the TensorFlow 1.x API (tf.is_nan, name_scope values=).


def retain_boxes_above_threshold(
    boxes, labels, label_scores, masks=None, keypoints=None, threshold=0.0):
  """Retains boxes whose label score is above a given threshold.

  A box is kept if its label score is strictly greater than `threshold`. If the
  label score for a box is missing (represented by NaN), the box is also
  retained. Boxes that fail this test do not appear in any of the returned
  tensors.

  Args:
    boxes: float32 tensor of shape [num_instances, 4] representing box
      locations in normalized coordinates.
    labels: rank 1 int32 tensor of shape [num_instances] containing the object
      classes.
    label_scores: float32 tensor of shape [num_instances] representing the
      score for each box.
    masks: (optional) rank 3 float32 tensor with shape
      [num_instances, height, width] containing instance masks. The masks have
      the same height and width as the input `image`.
    keypoints: (optional) rank 3 float32 tensor with shape
      [num_instances, num_keypoints, 2]. The keypoints are in y-x normalized
      coordinates.
    threshold: scalar Python float.

  Returns:
    A list containing:
    retained_boxes: [num_retained_instances, 4]
    retained_labels: [num_retained_instances]
    retained_label_scores: [num_retained_instances]

    If masks or keypoints are not None, the list also contains, in this order:

    retained_masks: [num_retained_instances, height, width]
    retained_keypoints: [num_retained_instances, num_keypoints, 2]
  """
  with tf.name_scope('RetainBoxesAboveThreshold',
                     values=[boxes, labels, label_scores]):
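    # A box passes if its score exceeds the threshold or its score is NaN
    # (missing). With a single argument, tf.where returns the indices of the
    # True entries with shape [num_retained_instances, 1]; squeeze them into a
    # rank-1 index vector.
    indices = tf.where(
        tf.logical_or(label_scores > threshold, tf.is_nan(label_scores)))
    indices = tf.squeeze(indices, axis=1)
    # Gather the same rows from every per-box tensor so they stay aligned.
    retained_boxes = tf.gather(boxes, indices)
    retained_labels = tf.gather(labels, indices)
    retained_label_scores = tf.gather(label_scores, indices)
    result = [retained_boxes, retained_labels, retained_label_scores]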

    if masks is not None:
      retained_masks = tf.gather(masks, indices)
      result.append(retained_masks)

    if keypoints is not None:
      retained_keypoints = tf.gather(keypoints, indices)
      result.append(retained_keypoints)

    return result
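To make the retention rule concrete without requiring TensorFlow, here is a minimal sketch that re-expresses it in plain Python. `retain_above_threshold_py` is a hypothetical helper written only for illustration; it is not part of TensorFlow or the Object Detection API, and it assumes the inputs are ordinary Python sequences of equal length. A list comprehension stands in for `tf.where` plus `tf.squeeze`, and index-based selection stands in for `tf.gather`.

```python
import math
from typing import Any, List, Optional, Sequence


def retain_above_threshold_py(
    boxes: Sequence[Any],
    labels: Sequence[Any],
    label_scores: Sequence[float],
    masks: Optional[Sequence[Any]] = None,
    keypoints: Optional[Sequence[Any]] = None,
    threshold: float = 0.0,
) -> List[List[Any]]:
    """Hypothetical pure-Python mirror of retain_boxes_above_threshold.

    Keeps row i when label_scores[i] > threshold or label_scores[i] is NaN.
    All input sequences are assumed to have the same length.
    """
    # Stand-in for tf.where(...) + tf.squeeze(axis=1): a flat list of the
    # indices whose score passes the strict threshold or is NaN (missing).
    indices = [
        i for i, score in enumerate(label_scores)
        if score > threshold or math.isnan(score)
    ]

    # Stand-in for tf.gather: pick the same rows from each parallel sequence
    # so boxes, labels, scores, masks and keypoints stay aligned.
    def gather(values: Sequence[Any]) -> List[Any]:
        return [values[i] for i in indices]

    result = [gather(boxes), gather(labels), gather(label_scores)]
    # Optional outputs are appended in the same order as the TF version.
    if masks is not None:
        result.append(gather(masks))
    if keypoints is not None:
        result.append(gather(keypoints))
    return result


if __name__ == "__main__":
    boxes = [[0.0, 0.0, 0.5, 0.5], [0.1, 0.1, 0.9, 0.9], [0.2, 0.2, 0.4, 0.4]]
    labels = [1, 2, 3]
    scores = [0.2, 0.7, float("nan")]
    kept_boxes, kept_labels, kept_scores = retain_above_threshold_py(
        boxes, labels, scores, threshold=0.5)
    print(kept_labels)  # prints [2, 3]
```

Because the comparison is strict, a score exactly equal to `threshold` is dropped, while a NaN score is always kept. As with the TensorFlow version, the caller has to know which optional arguments were passed in order to unpack the variable-length result correctly.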