tensorlfowapi.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:SSD_tensorflow_VOC 作者: LevinJ 项目源码 文件源码
def test_scatter_nd_3():
    """Prototype ground-truth-to-anchor matching with tf.scatter_nd in a tf.while_loop.

    ``jaccard[g, a]`` holds the overlap between ground-truth box ``g`` and
    anchor ``a``. For each ground-truth box ``i`` (in order), the anchor with
    the highest overlap in row ``i`` receives box ``i``'s label, coordinates
    and overlap score. A later ground-truth box that picks the same anchor
    overwrites the earlier assignment; anchors never picked keep their initial
    values.

    Returns:
        Tuple ``(gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores)`` of
        tensors produced by the loop.
    """
    gt_bboxes = tf.constant([[0,0,1,2],[1,0,3,4],[100,100,105,102.5]])
    gt_labels = tf.constant([1,2,6])

    jaccard = tf.constant( [[ 0. ,     0.  ,    0.02,    0.15  ],[ 0. ,     0.3125 , 0.08,    0.    ],[ 0.5 ,    0. ,     0.  ,    0.    ]])
    gt_anchors_scores = tf.constant([0.0,0.,0.,0.])
    gt_anchors_labels = tf.constant([100,100,100,100])
    gt_anchors_bboxes=tf.constant([[100,100,105,105],[2,1,3,3.5],[0,0,10,10],[0.5,0.5,0.8,1.5]])

    # Index of the best-overlapping anchor for each ground-truth box (row-wise argmax).
    max_inds = tf.cast(tf.argmax(jaccard, axis=1), tf.int32)

    # Loop-invariant sizes: the anchor tensors keep their shape across iterations.
    num_gt = tf.shape(gt_labels)[0]
    num_anchors = tf.shape(gt_anchors_bboxes)[0]
    anchors_shape_1d = tf.reshape(num_anchors, [1])

    def cond(i, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores):
        return tf.less(i, num_gt)

    def body(i, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores):
        anchor_idx = max_inds[i]
        # scatter_nd wants indices shaped [num_updates, index_depth] -> [1, 1].
        indices = tf.reshape(anchor_idx, [1, 1])
        # Select the matched anchor by position. Deriving the mask from the
        # scattered label value (cast to bool) would drop a match whose label is 0.
        new_mask = tf.equal(tf.range(num_anchors), anchor_idx)

        # Update gt_anchors_labels.
        new_labels = tf.scatter_nd(indices, tf.reshape(gt_labels[i], [1]), anchors_shape_1d)
        gt_anchors_labels = tf.where(new_mask, new_labels, gt_anchors_labels)

        # Update gt_anchors_bboxes. The rank-1 mask selects whole rows (TF1 tf.where semantics).
        new_bboxes = tf.scatter_nd(indices, tf.reshape(gt_bboxes[i], [1, -1]), tf.shape(gt_anchors_bboxes))
        gt_anchors_bboxes = tf.where(new_mask, new_bboxes, gt_anchors_bboxes)

        # Update gt_anchors_scores with the overlap of the chosen match.
        new_scores = tf.scatter_nd(indices, tf.reshape(jaccard[i, anchor_idx], [1]), anchors_shape_1d)
        gt_anchors_scores = tf.where(new_mask, new_scores, gt_anchors_scores)

        return [i+1,gt_anchors_labels,gt_anchors_bboxes,gt_anchors_scores]

    i = 0
    [_, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores] = tf.while_loop(
        cond, body, [i, gt_anchors_labels, gt_anchors_bboxes, gt_anchors_scores])

    return gt_anchors_labels,gt_anchors_bboxes,gt_anchors_scores
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号