detect_camera.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:yolo-tf 作者: ruiminshen 项目源码 文件源码
def main():
    """Run the configured YOLO model on live camera frames and show detections.

    Reads the model name and its input ``width``/``height`` from the
    module-level ``config``, builds the network on a single-image placeholder,
    restores the latest checkpoint from the log directory, and then loops
    forever: grab a frame from camera 0, run inference, apply non-max
    suppression, and draw the surviving boxes with their class label and
    confidence in an OpenCV window named ``'detection'``.

    The loop has no exit condition of its own; it ends only when an exception
    (e.g. ``KeyboardInterrupt``) propagates, at which point the window is
    destroyed and the camera released.

    Raises:
        RuntimeError: if no checkpoint is found in the log directory, or if a
            frame cannot be read from the camera.
    """
    model = config.get('config', 'model')
    yolo = importlib.import_module('model.' + model)
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    # The preprocessing function is looked up by name in the `detect` module.
    preprocess = getattr(importlib.import_module('detect'), args.preprocess)
    with tf.Session() as sess:
        ph_image = tf.placeholder(tf.float32, [1, height, width, 3], name='ph_image')
        builder = yolo.Builder(args, config)
        builder(ph_image)
        global_step = tf.contrib.framework.get_or_create_global_step()
        logdir = utils.get_logdir(config)
        model_path = tf.train.latest_checkpoint(logdir)
        if model_path is None:
            # latest_checkpoint returns None when nothing has been saved yet;
            # fail with a clear message instead of a TypeError on the
            # string concatenation below.
            raise RuntimeError('no checkpoint found in %s' % logdir)
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tensors = [builder.model.conf, builder.model.xy_min, builder.model.xy_max]
        # Wrap each output so NaN/Inf values raise at run time.
        tensors = [tf.check_numerics(t, t.op.name) for t in tensors]
        cap = cv2.VideoCapture(0)
        try:
            while True:
                ret, image_bgr = cap.read()
                if not ret:
                    # Explicit check instead of `assert`, which is stripped
                    # under `python -O`.
                    raise RuntimeError('failed to read a frame from the camera')
                image_height, image_width, _ = image_bgr.shape
                # Factors used to map box coordinates onto the original frame
                # (presumably the boxes are in grid-cell units -- confirm
                # against the model definition).
                scale = [image_width / builder.model.cell_width, image_height / builder.model.cell_height]
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                image_std = np.expand_dims(preprocess(cv2.resize(image_rgb, (width, height))).astype(np.float32), 0)
                feed_dict = {ph_image: image_std}
                conf, xy_min, xy_max = sess.run(tensors, feed_dict)
                boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou)
                for _conf, _xy_min, _xy_max in boxes:
                    index = np.argmax(_conf)
                    if _conf[index] > args.threshold:
                        # `np.int` was only an alias of the builtin `int` and
                        # was removed in NumPy 1.24; `int` behaves identically.
                        _xy_min = (_xy_min * scale).astype(int)
                        _xy_max = (_xy_max * scale).astype(int)
                        top_left = tuple(_xy_min)
                        bottom_right = tuple(_xy_max)
                        cv2.rectangle(image_bgr, top_left, bottom_right, (255, 0, 255), 3)
                        label = builder.names[index] + ' (%.1f%%)' % (_conf[index] * 100)
                        cv2.putText(image_bgr, label, top_left, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
                cv2.imshow('detection', image_bgr)
                # waitKey lets the HighGUI event loop repaint the window.
                cv2.waitKey(1)
        finally:
            cv2.destroyAllWindows()
            cap.release()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号