evaluate_combined.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ObjectDetection 作者: PhilippParis 项目源码 文件源码
def main(_):
    """Detect objects in one test image and append evaluation rows to a CSV.

    Steps, as performed below:
      1. Load the cascade classifier from ``FLAGS.cascade_xml`` and the
         neural-network classifier via ``get_nn_classifier()``.
      2. Load the image at ``FLAGS.test``, min-max normalise it to 8-bit,
         run the cascade to obtain candidate windows and pass them, together
         with the list ``delta``, to ``nn_classification``.
      3. For every entry of ``delta``, compare the detections with the
         ground truth read from the ``.csv`` file that shares the image's
         base name, and append one row to
         ``FLAGS.output_dir + FLAGS.out + '.csv'``.

    Args:
        _: Ignored. (Presumably the argv list handed over by the application
            runner -- TODO confirm against the caller.)
    """
    image_path = FLAGS.test
    # Ground truth lives next to the image: same path, ".csv" extension.
    csv_path = os.path.splitext(image_path)[0] + ".csv"

    # --------- load classifier ------- #
    cascade = cv2.CascadeClassifier(FLAGS.cascade_xml)
    model, x, keep_prob = get_nn_classifier()

    # ---------- object detection ------------#
    # Single-argument parenthesised print behaves identically on Python 2.
    print('starting detection of ' + FLAGS.test + '...')

    img = utils.getImage(image_path)
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

    # One result set is produced (and one CSV row written) per entry.
    # NOTE(review): these look like classifier decision thresholds -- confirm
    # against nn_classification.
    delta = [-2, -1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9,
             0.95, 0.99, 0.995, 0.999, 0.9995, 0.9999]

    # The timed section covers the cascade pass and the NN classification.
    start = time.time()
    candidates = cascade.detectMultiScale(
        img,
        scaleFactor=FLAGS.scaleFactor,
        minNeighbors=FLAGS.minNeighbors,
        maxSize=(FLAGS.max_window_size, FLAGS.max_window_size))
    detected = nn_classification(candidates, img, model, x, keep_prob, delta)
    elapsed = time.time() - start

    # %d truncates to whole seconds; the CSV row below keeps full precision.
    print('detection time: %d' % elapsed)

    # ------------- evaluation --------------#
    ground_truth_data = utils.get_ground_truth_data(csv_path)

    # ----------------output ----------------#
    # Disabled image output, kept for debugging:
    #     img_out = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    #     for (x,y,w,h) in detected[j]:
    #         cv2.rectangle(img_out, (x-w/2,y-h/2),(x+w/2,y+h/2), [0,255,0], 3)
    #     for c in ground_truth_data:
    #         cv2.circle(img_out, (c[0], c[1]), 3, [0,0,255],3)
    #     output_file = "out" + '_' + str(datetime.datetime.now())
    #     cv2.imwrite(FLAGS.output_dir + output_file + '.png', img_out)

    # csv output
    # The file is opened once for all rows instead of once per delta entry.
    # 'ab' is the mode the Python 2 csv module expects for appending. The path
    # is built by plain concatenation, so FLAGS.output_dir has to carry its
    # own trailing separator.
    with open(FLAGS.output_dir + FLAGS.out + '.csv', 'ab') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        for j, threshold in enumerate(delta):
            # Distinct loop names: Python 2 list-comprehension variables leak
            # into the enclosing scope and would otherwise overwrite `x`.
            rects = [Rect(rx, ry, rw, rh) for (rx, ry, rw, rh) in detected[j]]
            tp, fn, fp = utils.evaluate(ground_truth_data, rects)

            writer.writerow([FLAGS.test, str(elapsed),
                             str(len(ground_truth_data)), threshold,
                             FLAGS.minNeighbors, FLAGS.scaleFactor,
                             str(len(rects)), str(tp), str(fp), str(fn)])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号