def test_scatter_nd_3():
    """Assign each ground-truth box to its best-scoring anchor using ``tf.scatter_nd``.

    For each ground truth ``i`` (row ``i`` of ``jaccard``), the anchor with the
    highest score in that row (``argmax`` over axis 1) receives:

    - the ground truth's label,
    - its box coordinates,
    - that score.

    Ground truths are processed in order inside a ``tf.while_loop``. If two of
    them select the same anchor, the later one wins. Anchors that no ground
    truth selects keep their initial label (100), box, and score (0).

    With the constants below, ``max_inds`` is ``[3, 1, 0]``, so the result is:

    - labels ``[6, 2, 100, 1]``
    - boxes ``[[100, 100, 105, 102.5], [1, 0, 3, 4], [0, 0, 10, 10], [0, 0, 1, 2]]``
    - scores ``[0.5, 0.3125, 0.0, 0.15]``

    Returns:
        Tuple ``(gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores)`` of
        tensors with shapes ``[4]``, ``[4, 4]`` and ``[4]``. They must be
        evaluated by the caller (session run in TF1, or read eagerly in TF2 —
        TODO confirm which mode the caller uses).
    """
    gt_bboxes = tf.constant([[0, 0, 1, 2], [1, 0, 3, 4], [100, 100, 105, 102.5]])
    gt_labels = tf.constant([1, 2, 6])
    # Score matrix indexed [ground_truth, anchor]: 3 ground truths x 4 anchors.
    # Named "jaccard", so presumably IoU values.
    jaccard = tf.constant([[0.0, 0.0, 0.02, 0.15],
                           [0.0, 0.3125, 0.08, 0.0],
                           [0.5, 0.0, 0.0, 0.0]])
    gt_anchors_scores = tf.constant([0.0, 0.0, 0.0, 0.0])
    gt_anchors_labels = tf.constant([100, 100, 100, 100])
    gt_anchors_bboxes = tf.constant([[100, 100, 105, 105],
                                     [2, 1, 3, 3.5],
                                     [0, 0, 10, 10],
                                     [0.5, 0.5, 0.8, 1.5]])
    # Best anchor index for every ground truth.
    max_inds = tf.cast(tf.argmax(jaccard, axis=1), tf.int32)

    def cond(i, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores):
        # Iterate once per ground truth.
        return tf.less(i, tf.shape(gt_labels)[0])

    def body(i, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores):
        # All three updates scatter into the same anchor row, max_inds[i].
        indices = tf.reshape(max_inds[i], [1, 1])
        vec_shape = tf.reshape(tf.shape(gt_anchors_bboxes)[0], [-1])

        # Update gt_anchors_labels.
        label_updates = tf.reshape(gt_labels[i], [-1])
        new_labels = tf.scatter_nd(indices, label_updates, vec_shape)
        # Mark the scattered row by position (scatter of ones) rather than by
        # casting new_labels to bool, which would miss a ground truth whose
        # label is 0.
        row_mask = tf.cast(
            tf.scatter_nd(indices, tf.ones_like(label_updates), vec_shape),
            tf.bool)
        gt_anchors_labels = tf.where(row_mask, new_labels, gt_anchors_labels)

        # Update gt_anchors_bboxes.
        box_updates = tf.reshape(gt_bboxes[i], [1, -1])
        box_shape = tf.shape(gt_anchors_bboxes)
        new_bboxes = tf.scatter_nd(indices, box_updates, box_shape)
        # Expand the per-row mask to the full [num_anchors, 4] shape.
        # TF1's tf.where treats a rank-1 condition as row selection, but TF2's
        # broadcasting tf.where would align it with the coordinate axis instead.
        box_mask = tf.tile(tf.expand_dims(row_mask, 1), [1, box_shape[1]])
        gt_anchors_bboxes = tf.where(box_mask, new_bboxes, gt_anchors_bboxes)

        # Update gt_anchors_scores.
        score_updates = tf.reshape(jaccard[i, max_inds[i]], [-1])
        new_scores = tf.scatter_nd(indices, score_updates, vec_shape)
        gt_anchors_scores = tf.where(row_mask, new_scores, gt_anchors_scores)

        return [i + 1, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores]

    i = 0
    [i, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores] = tf.while_loop(
        cond, body, [i, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores])
    return gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores
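# NOTE: leftover blog-page navigation text from the original paste:
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")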