roi.py source code

Project: dataset  Author: analysiscenter
```python
import tensorflow as tf  # TF 1.x graph API (tf.variable_scope)


def non_max_suppression(inputs, scores, batch_size, max_output_size,
                        score_threshold=0.7, iou_threshold=0.7, nonempty=False, name='nms'):
    """ Perform NMS on a batch of images.

    Parameters
    ----------
        inputs: tf.Tuple
            each component is a set of bboxes for the corresponding image
        scores: tf.Tuple
            scores of the bboxes in inputs
        batch_size: int
            number of images in the batch
        max_output_size: int
            maximum number of bboxes kept per image
        score_threshold: float
            bboxes with a score less than score_threshold will be dropped
        iou_threshold: float
            bboxes whose IoU with an already selected bbox is greater than iou_threshold will be suppressed
        nonempty: bool
            if True, at least one bbox per image will be returned
        name: str
            scope name

    Returns
    -------
        tf.Tuple
            indices of selected bboxes for each image

    """
    with tf.variable_scope(name):
        ix = tf.constant(0)
        # One int32 slot per image; per-image results can have different lengths.
        filtered_rois = tf.TensorArray(dtype=tf.int32, size=batch_size, infer_shape=False)
        loop_cond = lambda ix, filtered_rois: tf.less(ix, batch_size)
        def _loop_body(ix, filtered_rois):
            # Drop bboxes below score_threshold; `indices` maps the survivors back to inputs[ix].
            # _filter_tensor is a helper defined elsewhere in roi.py.
            indices, score, roi = _filter_tensor(scores[ix], score_threshold, inputs[ix]) # pylint: disable=unbalanced-tuple-unpacking
            # Convert each bbox from (corner, size) to (corner, opposite corner) for tf.image.non_max_suppression.
            roi_corners = tf.concat([roi[:, :2], roi[:, :2]+roi[:, 2:]], axis=-1)
            roi_after_nms = tf.image.non_max_suppression(roi_corners, score, max_output_size, iou_threshold)
            if nonempty:
                # If no bbox passed the score filter, fall back to index 0 so the image is not empty.
                is_not_empty = lambda: filtered_rois.write(ix,
                                                           tf.cast(tf.gather(indices, roi_after_nms),
                                                                   dtype=tf.int32))
                is_empty = lambda: filtered_rois.write(ix, tf.constant([[0]]))
                filtered_rois = tf.cond(tf.not_equal(tf.shape(indices)[0], 0), is_not_empty, is_empty)
            else:
                # Map NMS output (positions among the filtered bboxes) back to original indices.
                filtered_rois = filtered_rois.write(ix, tf.cast(tf.gather(indices, roi_after_nms), dtype=tf.int32))
            return [ix+1, filtered_rois]
        _, res = tf.while_loop(loop_cond, _loop_body, [ix, filtered_rois])
        # _array_to_tuple (defined elsewhere in roi.py) converts the TensorArray into a per-image tuple.
        res = _array_to_tuple(res, batch_size, [-1, 1])
    return res
```
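To make the per-image logic concrete, here is a minimal pure-Python sketch of what the loop body computes for one image and how the batch loop assembles the results. It assumes each box is stored as a corner followed by its size, e.g. (y, x, h, w), which is what the `roi_corners` conversion implies. The names `to_corners`, `iou`, `nms_single` and `batched_nms` are hypothetical, and the score filter only stands in for `_filter_tensor`, whose source is not shown. The greedy loop is the standard NMS procedure that `tf.image.non_max_suppression` performs, written out by hand rather than copied from TensorFlow.

```python
from typing import List, Sequence, Tuple

# Assumed box layout: (y, x, h, w), i.e. a corner plus the box size.
Box = Tuple[float, float, float, float]
Corners = Tuple[float, float, float, float]


def to_corners(box: Box) -> Corners:
    """(y, x, h, w) -> (y1, x1, y2, x2), like the roi_corners line above."""
    y, x, h, w = box
    return (y, x, y + h, x + w)


def iou(a: Corners, b: Corners) -> float:
    """Intersection over union of two boxes in corner form."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def nms_single(boxes: Sequence[Box], scores: Sequence[float], max_output_size: int,
               score_threshold: float, iou_threshold: float) -> List[int]:
    """Greedy NMS for one image; returns indices into `boxes`."""
    # Score filter (hypothetical stand-in for _filter_tensor); keeps original indices.
    candidates = [i for i, s in enumerate(scores) if s >= score_threshold]
    # Visit the surviving boxes from highest to lowest score.
    candidates.sort(key=lambda i: scores[i], reverse=True)
    corners = [to_corners(b) for b in boxes]
    kept: List[int] = []
    for i in candidates:
        if len(kept) >= max_output_size:
            break
        # Suppress a box whose IoU with any already kept box exceeds iou_threshold.
        if all(iou(corners[i], corners[j]) <= iou_threshold for j in kept):
            kept.append(i)
    return kept


def batched_nms(batch_boxes, batch_scores, max_output_size,
                score_threshold=0.7, iou_threshold=0.7, nonempty=False):
    """Run nms_single per image; plays the role of the tf.while_loop over the batch."""
    result = []
    for boxes, scores in zip(batch_boxes, batch_scores):
        kept = nms_single(boxes, scores, max_output_size, score_threshold, iou_threshold)
        # Same fallback as the TF code: if nothing passed the score filter,
        # return index 0 so the image is never empty.
        if nonempty and not any(s >= score_threshold for s in scores):
            kept = [0]
        result.append(kept)
    return result


image_a_boxes = [(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 5, 5)]
image_a_scores = [0.9, 0.8, 0.75]
image_b_boxes = [(0, 0, 4, 4)]
image_b_scores = [0.1]  # below score_threshold

print(batched_nms([image_a_boxes, image_b_boxes], [image_a_scores, image_b_scores],
                  max_output_size=10, iou_threshold=0.5, nonempty=True))
# → [[0, 2], [0]]
```

In the example, boxes 0 and 1 overlap with IoU 81/119 ≈ 0.68, so box 1 is suppressed at `iou_threshold=0.5` but would survive at the default 0.7. Unlike the TensorFlow version, this sketch runs eagerly and returns flat index lists instead of `[-1, 1]`-shaped tensors, so it is meant only to illustrate the selection logic.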