inferrable_test.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:num-seq-recognizer 作者: gmlove 项目源码 文件源码
def test_combined_infer(self):
  """Run the bbox model and the number-sequence model on sample images and print results.

  Loads the training metadata pickle, then for images ``1.png`` .. ``10.png``:

  1. Runs the bbox graph (``test_helper.output_bbox_graph_file``) on the full
     image resized to the lenet_v2 input size. It prints the labelled bbox next
     to the predicted bbox, scaled from rates back to pixel coordinates.
  2. Runs the recognizer graph (``test_helper.output_graph_file``) on the image
     cropped to its labelled bbox and resized to the iclr_mnr input size. It
     prints the label, the argmax of the first 5 outputs and the per-position
     argmax of the remaining 5x11 outputs.

  Nothing is asserted; the output is printed for manual inspection.
  """
  from nsrec.nets import iclr_mnr, lenet_v2
  from six.moves import cPickle as pickle

  data_dir = test_helper.train_data_dir_path
  # Use a context manager so the metadata file handle is closed promptly.
  # The previous open(...).read() left it open until garbage collection.
  # NOTE: pickle is acceptable here only because this is a trusted local test fixture.
  with open(data_dir + '/metadata.pickle', 'rb') as metadata_file:
    metadata = pickle.load(metadata_file)

  def test_img_data_generator(new_size, crop_bbox=False):
    """Yield ``(resized_img, (orig_width, orig_height), bbox, label)`` for 1.png..10.png.

    ``new_size`` is indexed as ``[width, height]`` and passed to
    ``inputs.normalize_img``. When ``crop_bbox`` is true, the image is re-read
    with its labelled bbox before resizing. The yielded width/height always
    describe the uncropped image.
    """
    for i in range(10):
      filename = '%s.png' % (i + 1)
      img_path = os.path.join(data_dir, filename)
      img_idx = metadata['filenames'].index(filename)
      bbox, label = metadata['bboxes'][img_idx], metadata['labels'][img_idx]
      input_data = inputs.read_img(img_path)
      # shape[0] is treated as height and shape[1] as width (row-major image array).
      width, height = input_data.shape[1], input_data.shape[0]
      if crop_bbox:
        input_data = inputs.read_img(img_path, bbox)
      input_data = inputs.normalize_img(input_data, [new_size[0], new_size[1]])
      yield (input_data, (width, height), bbox, label)

  bbox_model = Inferrable(test_helper.output_bbox_graph_file, 'initializer-bbox', 'input-bbox', 'output-bbox')
  for input_data, (width, height), bbox, _ in test_img_data_generator([lenet_v2.image_width, lenet_v2.image_height]):
    # Batch of one image.
    bbox_in_rate = bbox_model.infer(np.array([input_data]))
    print(width, height)
    # Outputs 0/2 are scaled by width and 1/3 by height. This presumably means
    # [left, top, box_width, box_height] as fractions of the image (TODO confirm).
    print('label bbox: %s, bbox: %s' % (bbox, [bbox_in_rate[0] * width, bbox_in_rate[1] * height,
                                               bbox_in_rate[2] * width, bbox_in_rate[3] * height]))

  nsr_model = Inferrable(test_helper.output_graph_file, 'initializer', 'input', 'output')
  for input_data, _, _, label in test_img_data_generator([iclr_mnr.image_width, iclr_mnr.image_height], True):
    pb = nsr_model.infer(np.array([input_data]))
    # pb[:5] is read as a 5-way length score and pb[5:] as 5 positions x 11 classes.
    # The class meaning (presumably digits 0-9 plus "no digit") is not shown here — TODO confirm.
    print('actual: %s, length pb: %s, numbers: %s' % (
      label, np.argmax(pb[:5]), np.argmax(pb[5:].reshape([5, 11]), axis=1)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号