def _get_bbox_pred(self, proposed_boxes, gt_boxes_per_class):
    """Build the expected bbox_pred tensor covering every class.

    Each class's ground-truth boxes are encoded against the same set of
    proposals with ``encode``. The per-class results are then placed side
    by side along axis 1.

    Args:
        proposed_boxes: Proposals tensor. It is described as having shape
            (num_proposals, 5). NOTE(review): confirm this column count
            matches what ``encode`` expects.
        gt_boxes_per_class: Ground-truth boxes grouped by class, with shape
            (num_classes, num_gt_boxes_per_class, 4).

    Returns:
        A tensor of shape (num_proposals, num_classes * 4) holding the
        expected bbox_pred values.
    """
    def _encode_against_proposals(class_gt_boxes):
        # Encode a single class's gt boxes relative to the shared proposals.
        return encode(proposed_boxes, class_gt_boxes)

    # Run the encoder once per class. map_fn stacks the per-class outputs
    # along a new leading axis.
    per_class_preds = tf.map_fn(
        _encode_against_proposals,
        gt_boxes_per_class,
        dtype=tf.float32,
    )

    # Unstack explicitly so tf.concat receives the per-class tensors as
    # separate values and joins them along axis 1.
    return tf.concat(tf.unstack(per_class_preds), axis=1)
评论列表
文章目录