def detect(sess, model, names, image, path):
    """Run the detector on one image file and plot the detections.

    Reads the image at ``path`` and resizes it to the spatial size of the
    ``image`` input tensor. It then evaluates ``model``'s confidence and box
    tensors through ``sess`` and filters them with
    ``utils.postprocess.non_max_suppress``. Every surviving box whose best
    class score exceeds ``args.threshold`` is drawn on a new matplotlib
    figure over the full-resolution image.

    NOTE(review): this function reads the module-level ``args`` namespace
    (``preprocess``, ``threshold``, ``threshold_iou``) instead of taking
    those values as parameters.

    Args:
        sess: TensorFlow session used to evaluate the model tensors.
        model: Object exposing ``conf``, ``xy_min`` and ``xy_max`` tensors
            and ``cell_width`` / ``cell_height`` attributes.
        names: Class names, indexed by the argmax of a box's ``conf`` vector.
        image: Rank-4 input placeholder; its dims 1 and 2 are used as the
            resize target (assumed NHWC layout - TODO confirm).
        path: Image location handed to ``read_image``.

    Returns:
        The matplotlib ``Figure`` holding the annotated image.
    """
    # SECURITY(review): ``eval`` executes arbitrary code taken from the
    # command line. This is acceptable only while ``args.preprocess`` is
    # trusted; a lookup table of allowed preprocessing functions would be
    # safer.
    preprocess = eval(args.preprocess)
    _, height, width, channels = image.get_shape().as_list()
    _image = read_image(path)

    # Full-resolution copy used only for display. A 2-D (grayscale) array is
    # promoted to 3 channels so it is shown in gray, not through a colormap.
    image_original = np.array(np.uint8(_image))
    if len(image_original.shape) == 2:
        image_original = np.repeat(np.expand_dims(image_original, -1), 3, 2)
    image_height, image_width, _ = image_original.shape

    # Network input, resized to the placeholder's spatial size. A grayscale
    # image yields a 2-D array, which cannot be fed to the rank-4
    # placeholder. Give it a channel axis, replicated to the placeholder's
    # channel count (or 3 when that dimension is unknown).
    resized = np.array(np.uint8(_image.resize((width, height))))
    if resized.ndim == 2:
        resized = np.repeat(np.expand_dims(resized, -1), channels or 3, 2)
    image_std = preprocess(resized.astype(np.float32))
    feed_dict = {image: np.expand_dims(image_std, 0)}

    # check_numerics makes the run fail loudly if an output contains NaN/Inf.
    tensors = [model.conf, model.xy_min, model.xy_max]
    conf, xy_min, xy_max = sess.run([tf.check_numerics(t, t.op.name) for t in tensors], feed_dict=feed_dict)
    boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou)

    # Factors mapping box coordinates onto pixels of the original image.
    # Presumably the boxes are expressed in grid-cell units (TODO confirm).
    scale = [image_width / model.cell_width, image_height / model.cell_height]
    fig = plt.figure()
    ax = fig.gca()
    ax.imshow(image_original)
    # One colour per class name, cycling through matplotlib's property cycle.
    colors = [prop['color'] for _, prop in zip(names, itertools.cycle(plt.rcParams['axes.prop_cycle']))]
    cnt = 0
    for _conf, _xy_min, _xy_max in boxes:
        index = np.argmax(_conf)
        if _conf[index] > args.threshold:
            wh = _xy_max - _xy_min
            _xy_min = _xy_min * scale
            _wh = wh * scale
            # Thicker outline for more confident boxes, capped at 3 points.
            linewidth = min(_conf[index] * 10, 3)
            ax.add_patch(patches.Rectangle(_xy_min, _wh[0], _wh[1], linewidth=linewidth, edgecolor=colors[index], facecolor='none'))
            ax.annotate(names[index] + ' (%.1f%%)' % (_conf[index] * 100), _xy_min, color=colors[index])
            cnt += 1

    # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6. The
    # figure manager's method exists on both old and new versions. Figures
    # without a manager have no window to title.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title('%d objects detected' % cnt)
    ax.set_xticks([])
    ax.set_yticks([])
    return fig
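# 评论列表 ("comment list") - blog-page navigation residue, not code
# 文章目录 ("table of contents") - blog-page navigation residue, not code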