def test_scatter_nd_2():
    """Scatter ground-truth labels and boxes into anchor slots with ``tf.scatter_update``.

    Despite the ``scatter_nd`` in the name, this uses ``tf.scatter_update``,
    which writes ``updates[i]`` into row ``indices[i]`` of the variable.

    Three ground-truth boxes and their labels are written into rows 1, 0 and 3
    (``max_inds``) of two non-trainable variables. Row 2 keeps its initial
    value. Once evaluated, the results should be:

    - labels: ``[2, 1, 100, 6]``
    - bboxes: ``[[1, 0, 3, 4], [0, 0, 1, 2], [0, 0, 10, 10], [100, 100, 105, 102.5]]``

    Both variables are registered only in ``LOCAL_VARIABLES``. That replaces
    the default collection, so they must be initialized by the local-variable
    initializer before evaluation. NOTE(review): presumably the caller runs
    ``tf.local_variables_initializer()``; confirm.

    Returns:
        A ``(labels, bboxes)`` pair: the tensors returned by the two
        ``tf.scatter_update`` calls.
    """
    # Explicit dtypes: scatter_update requires `updates` to match the
    # variable's dtype, so don't rely on literal-based inference.
    gt_bboxes = tf.constant(
        [[0, 0, 1, 2], [1, 0, 3, 4], [100, 100, 105, 102.5]],
        dtype=tf.float32,
    )
    gt_labels = tf.constant([1, 2, 6], dtype=tf.int32)

    gt_anchors_labels = tf.Variable(
        [100, 100, 100, 100],
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        dtype=tf.int32,
    )
    gt_anchors_bboxes = tf.Variable(
        [[100, 100, 105, 105], [2, 1, 3, 3.5], [0, 0, 10, 10], [0.5, 0.5, 0.8, 1.5]],
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        dtype=tf.float32,
    )

    # Target rows for ground-truth entries 0, 1 and 2 respectively.
    max_inds = [1, 0, 3]
    gt_anchors_labels = tf.scatter_update(gt_anchors_labels, max_inds, gt_labels)
    gt_anchors_bboxes = tf.scatter_update(gt_anchors_bboxes, max_inds, gt_bboxes)
    return gt_anchors_labels, gt_anchors_bboxes
# Web-page residue from the source article: "评论列表" (comment list), "文章目录" (table of contents).