detect.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目: yolo-tf　作者: ruiminshen　项目源码　文件源码
def _to_rgb(array):
    """Return *array* as a height x width x 3 image array.

    A 2-D (grayscale) array, or a 3-D array with a single channel, is
    replicated across three channels. An array with more than three
    channels (e.g. RGBA) keeps only the first three. A 3-channel array is
    returned unchanged.
    """
    if array.ndim == 2:
        array = np.expand_dims(array, -1)
    if array.shape[-1] == 1:
        return np.repeat(array, 3, -1)
    if array.shape[-1] > 3:
        return array[..., :3]
    return array


def detect(sess, model, names, image, path):
    """Run the detector on the image at *path* and plot the detections.

    Args:
        sess: session used to evaluate the model tensors (``sess.run``).
        model: detection model exposing the ``conf``, ``xy_min`` and ``xy_max``
            tensors plus the ``cell_width`` / ``cell_height`` grid sizes.
        names: class names, indexed by the arg-max of each box's confidences.
        image: input placeholder tensor with a static rank-4 shape whose
            middle two dimensions are the network's input height and width.
        path: path of the image file, loaded with ``read_image``.

    Returns:
        A matplotlib figure showing the original image, with one rectangle
        and one label per detection whose best class confidence exceeds
        ``args.threshold``.
    """
    # NOTE(review): ``eval`` executes arbitrary code taken from the
    # command-line option ``args.preprocess``. This is acceptable for a local
    # CLI tool, but it must never receive untrusted input; a name -> function
    # lookup table would be safer.
    preprocess = eval(args.preprocess)
    _, height, width, _ = image.get_shape().as_list()
    _image = read_image(path)

    # Full-resolution copy used for display, normalised to three channels.
    image_original = _to_rgb(np.array(np.uint8(_image)))
    image_height, image_width, _ = image_original.shape

    # The network sees a resized copy. It needs the same channel handling as
    # the displayed copy; otherwise a grayscale file yields a 2-D array that
    # does not match the placeholder's rank and sess.run fails.
    image_resized = _to_rgb(np.array(np.uint8(_image.resize((width, height)))))
    image_std = preprocess(image_resized.astype(np.float32))
    feed_dict = {image: np.expand_dims(image_std, 0)}

    # check_numerics makes the run fail loudly if any output holds NaN/Inf.
    tensors = [model.conf, model.xy_min, model.xy_max]
    conf, xy_min, xy_max = sess.run([tf.check_numerics(t, t.op.name) for t in tensors], feed_dict=feed_dict)
    boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou)

    # Box coordinates are multiplied by image-size / cell-count, i.e. they are
    # presumably expressed in grid-cell units (TODO confirm against the model).
    scale = [image_width / model.cell_width, image_height / model.cell_height]

    fig = plt.figure()
    ax = fig.gca()
    ax.imshow(image_original)
    # One colour per class name, cycling through the active property cycle.
    colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))]
    cnt = 0
    for _conf, _xy_min, _xy_max in boxes:
        index = np.argmax(_conf)
        if _conf[index] > args.threshold:
            wh = _xy_max - _xy_min
            _xy_min = _xy_min * scale
            _wh = wh * scale
            # Thicker outline for more confident boxes, capped at 3.
            linewidth = min(_conf[index] * 10, 3)
            ax.add_patch(patches.Rectangle(_xy_min, _wh[0], _wh[1], linewidth=linewidth, edgecolor=colors[index], facecolor='none'))
            ax.annotate(names[index] + ' (%.1f%%)' % (_conf[index] * 100), _xy_min, color=colors[index])
            cnt += 1

    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; the
    # figure manager owns the window title (it may be None for figures that
    # have no manager).
    manager = fig.canvas.manager
    if manager is not None:
        manager.set_window_title('%d objects detected' % cnt)
    ax.set_xticks([])
    ax.set_yticks([])
    return fig
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号