def __match_no_miss(self, gt_anchor_labels, gt_anchor_bboxes, gt_anchor_scores,
                    jaccard, gt_labels, gt_bboxes, num_anchors):
    """Make sure every ground-truth box is matched to at least one anchor.

    Ground-truth boxes are visited in index order. For box ``i``, the anchor
    with the highest overlap, ``argmax(jaccard[i])``, is overwritten with that
    box's label, its box coordinates and the overlap value ``jaccard[i, best]``.
    If several ground-truth boxes share the same best anchor, the one with the
    highest index wins, because it is written last.

    Args:
        gt_anchor_labels: per-anchor labels, shape ``[num_anchors]``; must have
            the same dtype as ``gt_labels``.
        gt_anchor_bboxes: per-anchor boxes, shape ``[num_anchors, D]``
            (presumably D == 4 -- TODO confirm); same dtype as ``gt_bboxes``.
        gt_anchor_scores: per-anchor overlap scores, shape ``[num_anchors]``;
            same dtype as ``jaccard``.
        jaccard: overlap matrix, shape ``[num_gt, num_anchors]``.
        gt_labels: ground-truth labels, shape ``[num_gt]``.
        gt_bboxes: ground-truth boxes, shape ``[num_gt, D]``.
        num_anchors: number of anchors. Retained for interface compatibility;
            the anchor count is read from ``gt_anchor_labels`` instead, so its
            dtype always matches the int32 argmax indices.

    Returns:
        Tuple ``(gt_anchor_labels, gt_anchor_bboxes, gt_anchor_scores)`` with
        the forced matches applied.
    """
    # For each ground-truth box, the index of the anchor it overlaps most.
    max_inds = tf.cast(tf.argmax(jaccard, axis=1), tf.int32)
    # Anchor indices 0..num_anchors-1 (int32), used to build one-hot masks.
    anchor_ids = tf.range(tf.shape(gt_anchor_labels)[0])

    def cond(i, labels, bboxes, scores):
        return tf.less(i, tf.shape(gt_labels)[0])

    def body(i, labels, bboxes, scores):
        best = max_inds[i]
        # Select the best anchor by index rather than by the label value, so a
        # ground truth whose label happens to be 0 is still matched.
        mask = tf.equal(anchor_ids, best)
        labels = tf.where(mask, tf.zeros_like(labels) + gt_labels[i], labels)
        scores = tf.where(mask, tf.zeros_like(scores) + jaccard[i, best], scores)
        # Expand the mask to [num_anchors, D] so tf.where gets same-shaped
        # operands: rank-1 row selection only works with the TF1 tf.where.
        box_mask = tf.tile(tf.expand_dims(mask, -1), [1, tf.shape(bboxes)[1]])
        bboxes = tf.where(box_mask, tf.zeros_like(bboxes) + gt_bboxes[i], bboxes)
        return [i + 1, labels, bboxes, scores]

    i = 0
    [i, gt_anchor_labels, gt_anchor_bboxes, gt_anchor_scores] = tf.while_loop(
        cond, body, [i, gt_anchor_labels, gt_anchor_bboxes, gt_anchor_scores])
    return gt_anchor_labels, gt_anchor_bboxes, gt_anchor_scores
# (web-page residue) 评论列表 -- "comment list"
# (web-page residue) 文章目录 -- "table of contents"