inference.py source code (Python)

Project: ssd_tensorflow, author: railsnoob
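import os
import pickle
import random

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np

# non_max_suppression_fast() is defined elsewhere in the ssd_tensorflow project
# (not shown on this page).


def predict_boxes(self, model_name="trained-model"):
    """Run detection on a random training image and show the detected objects.

    The image is picked from the training-run metadata in self.dirname/train.pkl
    and loaded from the directory given by the "images_path" config entry.

    Args:
        model_name: Name of the trained model passed to self.run_inference().

    Returns:
        None. Displays the image with the detected boxes drawn on it.
    """
    images_path = self.cfg.g("images_path")

    # Metadata for the training images, saved during the training run.
    with open(os.path.join(self.dirname, "train.pkl"), "rb") as f:
        train_imgs = pickle.load(f)

    image_info = random.choice(list(train_imgs))
    image_name = image_info["img_name"]
    p_conf, p_loc, p_probs = self.run_inference(image_name, model_name)

    # Default boxes whose predicted class is non-zero (class 0 is treated as background).
    non_zero_indices = np.where(p_conf > 0)[1]

    # DEBUGGING
    print("p_conf={} p_loc={} p_probs={}".format(p_conf.shape, p_loc.shape, p_probs.shape))
    print("Non zero indices", non_zero_indices)
    for i in non_zero_indices:
        # p_loc stores 4 location values per default box.
        print(i, ") location", p_loc[0][i * 4:i * 4 + 4],
              "probs", p_probs[0][i], "conf", p_conf[0][i])

    boxes, confs = self.convert_coordinates_to_boxes(p_loc, p_conf, p_probs)
    print("Boxes BEFORE NMS")
    for i, a in enumerate(zip(boxes, confs)):
        print(i, a)

    # Non-maximum suppression with an overlap threshold of 0.3.
    boxes = non_max_suppression_fast(boxes, 0.3)

    print("Boxes AFTER NMS")
    print(boxes)

    img = mpimg.imread(os.path.join(images_path, image_name))

    # Draw the remaining boxes in green; 2 is presumably the line thickness.
    self.debug_draw_boxes(img, boxes, (0, 255, 0), 2)

    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.show()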
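In predict_boxes, `np.where(p_conf > 0)[1]` picks out the default boxes with a non-background prediction, and `p_loc[0][i*4:i*4+4]` reads the four location values of box i. The toy example below illustrates both steps, assuming p_conf holds one predicted class id per default box (0 = background) with a batch dimension of 1. The arrays are made-up values, not real model output.

```python
import numpy as np

# Hypothetical predictions for one image (batch size 1) and 5 default boxes:
# one predicted class id per box, with 0 treated as background.
p_conf = np.array([[0, 3, 0, 0, 1]])
# 4 location values per default box, flattened into shape (1, 4 * 5).
p_loc = np.arange(20, dtype=float).reshape(1, 20)

# np.where on a 2-D array returns (row_indices, column_indices);
# [1] keeps the column indices, i.e. the default-box indices.
non_zero_indices = np.where(p_conf > 0)[1]

# Gather the 4 location values of every selected box in one step,
# equivalent to p_loc[0][i*4:i*4+4] for each index i.
selected_loc = p_loc.reshape(1, -1, 4)[0, non_zero_indices]

print(non_zero_indices)
print(selected_loc)
```

Reshaping to (batch, boxes, 4) avoids the manual `i*4` arithmetic when the selected boxes are needed as a single array.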
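convert_coordinates_to_boxes() is not shown on this page. For reference, the original SSD paper decodes each default box's four predicted offsets with a center-size encoding. The function below is a minimal sketch of that decoding, assuming p_loc stores (l_cx, l_cy, l_w, l_h) for every default box and that default boxes are given as (cx, cy, w, h). The name decode_ssd_offsets and the encoding itself are assumptions; the project may encode locations differently (for example as direct corner offsets).

```python
import numpy as np


def decode_ssd_offsets(p_loc, default_boxes):
    """Hypothetical helper: decode SSD-style offsets into corner boxes.

    Assumes the center-size encoding from the SSD paper:
        cx = d_cx + l_cx * d_w,   cy = d_cy + l_cy * d_h
        w  = d_w * exp(l_w),      h  = d_h * exp(l_h)

    p_loc:         4 offsets per default box, flat or 2-D, e.g. shape (1, 4 * N).
    default_boxes: shape (N, 4), rows are (d_cx, d_cy, d_w, d_h).
    Returns an (N, 4) array of (x1, y1, x2, y2).
    """
    loc = np.asarray(p_loc, dtype=float).reshape(-1, 4)
    d = np.asarray(default_boxes, dtype=float).reshape(-1, 4)

    # Shift the default box center, scaled by the default box size.
    cx = d[:, 0] + loc[:, 0] * d[:, 2]
    cy = d[:, 1] + loc[:, 1] * d[:, 3]
    # Scale the default box size by exp of the predicted log-ratios.
    w = d[:, 2] * np.exp(loc[:, 2])
    h = d[:, 3] * np.exp(loc[:, 3])

    # Convert center-size to corner form.
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)


# Zero offsets return the default box itself, in corner form.
print(decode_ssd_offsets([[0, 0, 0, 0]], [[0.5, 0.5, 0.2, 0.4]]))
```

Predicting log-ratios for width and height keeps the decoded sizes positive regardless of the network's raw output.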
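non_max_suppression_fast() is not shown on this page, and predict_boxes calls it with the boxes and an overlap threshold of 0.3 but without the confidence scores. As a point of comparison, here is a sketch of the common greedy, score-ordered IoU variant of non-maximum suppression. greedy_nms is a hypothetical name, and this is not necessarily what the project's function does.

```python
import numpy as np


def greedy_nms(boxes, scores, iou_thresh=0.3):
    """Return indices of the boxes kept by greedy IoU-based NMS.

    boxes:  shape (N, 4), rows are (x1, y1, x2, y2) with positive area.
    scores: shape (N,), higher means more confident.
    """
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    scores = np.asarray(scores, dtype=float)
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)

    order = np.argsort(scores)[::-1]  # most confident first
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]

        # Intersection of the kept box with every remaining box.
        ix1 = np.maximum(x1[best], x1[rest])
        iy1 = np.maximum(y1[best], y1[rest])
        ix2 = np.minimum(x2[best], x2[rest])
        iy2 = np.minimum(y2[best], y2[rest])
        inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
        iou = inter / (areas[best] + areas[rest] - inter)

        # Drop every remaining box that overlaps the kept one too much.
        order = rest[iou <= iou_thresh]
    return keep


boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores, 0.3))  # [0, 2]
```

Here the second box has an IoU of about 0.68 with the first and is suppressed, while the third does not overlap at all. Ordering by score means the most confident detection survives each overlapping cluster, which is why this variant takes the scores as an input.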