mobilenetdet.py source code

python

Project: MobileNet  Author: Zehaos
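import tensorflow as tf

# xywh_to_yxyx, yxyx_to_xywh_, batch_iou_fast and batch_delta are assumed to be
# defined elsewhere in mobilenetdet.py.


def encode_annos(image, labels, bboxes, anchors, num_classes):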
  """Encode annotations for loss computation.

  All output tensors have a fixed shape (no dynamic dimensions).

  Args:
    image: 3-D with shape `[H, W, C]`. Not used inside this function.
    labels: 1-D with shape `[num_bounding_boxes]`.
    bboxes: 2-D with shape `[num_bounding_boxes, 4]`. Scaled.
    anchors: 4-D tensor with shape `[fea_h, fea_w, num_anchors, 4]`.
    num_classes: number of classes, used as the one-hot depth for `labels`.

  Returns:
    In the shapes below, `num_anchors` is the flattened total
    `fea_h * fea_w * num_anchors`.

    input_mask: 2-D with shape `[num_anchors, 1]`, indicates which anchors are used to compute the loss.
    labels_input: 2-D with shape `[num_anchors, num_classes]`, one-hot encoding for every anchor.
    box_delta_input: 2-D with shape `[num_anchors, 4]`.
    box_input: 2-D with shape `[num_anchors, 4]`.
  """
  anchors_shape = anchors.get_shape().as_list()
  fea_h = anchors_shape[0]
  fea_w = anchors_shape[1]
  num_anchors = anchors_shape[2] * fea_h * fea_w
  anchors = tf.reshape(anchors, [num_anchors, 4])  # flatten to [total_anchors, 4]

  # Compute IoU and find the target (best-matching) anchor for each ground-truth box
  _anchors = xywh_to_yxyx(anchors)
  ious, indices = batch_iou_fast(_anchors, bboxes)
  indices = tf.reshape(indices, shape=[-1, 1])

  target_anchors = tf.gather(anchors, indices)
  target_anchors = tf.squeeze(target_anchors, axis=1)
  delta = batch_delta(yxyx_to_xywh_(bboxes), target_anchors)

  # bbox
  box_input = tf.scatter_nd(
    indices,
    bboxes,
    shape=[num_anchors, 4]
  )
  # label
  labels_input = tf.scatter_nd(
    indices,
    tf.one_hot(labels, num_classes),
    shape=[num_anchors, num_classes]
  )
  # anchor mask
  onehot_anchor = tf.one_hot(indices, num_anchors)
  onehot_anchor = tf.squeeze(onehot_anchor, axis=1)
  print("indices shape:", indices.get_shape().as_list())
  print("one hot anchors shape:", onehot_anchor.get_shape().as_list())
  input_mask = tf.reduce_sum(onehot_anchor, axis=0)
  input_mask = tf.reshape(input_mask, shape=[-1, 1])
  # delta
  box_delta_input = tf.scatter_nd(
    indices,
    delta,
    shape=[num_anchors, 4]
  )

  return input_mask, labels_input, box_delta_input, box_input


# TODO(shizehao): align anchor center to the grid
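
The scatter-based encoding in `encode_annos` can be illustrated with a minimal pure-Python sketch. It is not the project's code: `scatter_rows` and `anchor_mask` are hypothetical stand-ins for `tf.scatter_nd` and for the `tf.one_hot` + `tf.reduce_sum` mask, and the toy boxes and matched indices are made up. The sketch also shows a case worth checking in the real pipeline: `tf.scatter_nd` sums updates that share an index, so when two boxes match the same anchor, that anchor's row holds the sum of both boxes and its mask value becomes 2 rather than 1.

```python
def scatter_rows(indices, updates, num_rows):
    """Stand-in for tf.scatter_nd with row indices; duplicate indices are summed."""
    width = len(updates[0])
    out = [[0.0] * width for _ in range(num_rows)]
    for idx, row in zip(indices, updates):
        for j, value in enumerate(row):
            out[idx][j] += value
    return out


def anchor_mask(indices, num_rows):
    """Stand-in for reduce_sum(one_hot(indices, num_rows), axis=0)."""
    mask = [0.0] * num_rows
    for idx in indices:
        mask[idx] += 1.0
    return mask


# Hypothetical toy data: 6 flattened anchors, 3 ground-truth boxes.
num_anchors = 6
matched = [4, 1, 4]  # boxes 0 and 2 are both matched to anchor 4
boxes = [
    [0.25, 0.25, 0.5, 0.5],
    [0.5, 0.5, 1.0, 1.0],
    [0.125, 0.125, 0.25, 0.25],
]

box_input = scatter_rows(matched, boxes, num_anchors)
input_mask = anchor_mask(matched, num_anchors)

print(input_mask)    # anchor 4 is counted twice
print(box_input[4])  # element-wise sum of boxes 0 and 2
```

If such collisions can occur in the data, one option is to clip the mask to `[0, 1]` and keep only one box per anchor before scattering. Whether the original project needs this depends on how `batch_iou_fast` assigns anchors, which is not shown here.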