def test_combined_infer(self):
    """Smoke-test the exported bbox graph and number-recognition graph.

    First runs the bbox graph over training images ``1.png`` .. ``10.png`` and
    prints the predicted box (rescaled to image pixels) next to the labelled
    bbox. Then runs the number-recognition graph over the same images cropped
    to their labelled bboxes and prints the predicted length / digits next to
    the label. Nothing is asserted; the output is for manual inspection.
    """
    from nsrec.nets import iclr_mnr, lenet_v2
    from six.moves import cPickle as pickle

    data_dir = test_helper.train_data_dir_path
    # Context manager so the metadata file handle is closed deterministically
    # (a bare ``open(...).read()`` leaks it until garbage collection).
    with open(os.path.join(data_dir, 'metadata.pickle'), 'rb') as f:
        metadata = pickle.load(f)

    def test_img_data_generator(new_size, crop_bbox=False):
        """Yield ``(input_data, (width, height), bbox, label)`` for images 1..10.

        ``width``/``height`` are always taken from the full, uncropped image,
        because the bbox predictions below are rescaled against them. When
        ``crop_bbox`` is true, the image is re-read cropped to its labelled
        bbox before being normalized to ``new_size``.
        """
        for i in range(1, 11):
            filename = '%s.png' % i
            img_path = os.path.join(data_dir, filename)
            img_idx = metadata['filenames'].index(filename)
            bbox, label = metadata['bboxes'][img_idx], metadata['labels'][img_idx]
            input_data = inputs.read_img(img_path)
            width, height = input_data.shape[1], input_data.shape[0]
            if crop_bbox:
                input_data = inputs.read_img(img_path, bbox)
            input_data = inputs.normalize_img(input_data, [new_size[0], new_size[1]])
            yield (input_data, (width, height), bbox, label)

    bbox_model = Inferrable(test_helper.output_bbox_graph_file, 'initializer-bbox', 'input-bbox', 'output-bbox')
    for input_data, (width, height), bbox, _ in test_img_data_generator(
            [lenet_v2.image_width, lenet_v2.image_height]):
        # Outputs are multiplied by width/height below, so they are presumably
        # fractions of the full image size -- TODO confirm against the model.
        bbox_in_rate = bbox_model.infer(np.array([input_data]))
        print(width, height)
        print('label bbox: %s, bbox: %s' % (bbox, [bbox_in_rate[0] * width, bbox_in_rate[1] * height,
                                                   bbox_in_rate[2] * width, bbox_in_rate[3] * height]))

    nsr_model = Inferrable(test_helper.output_graph_file, 'initializer', 'input', 'output')
    for input_data, _, _, label in test_img_data_generator(
            [iclr_mnr.image_width, iclr_mnr.image_height], crop_bbox=True):
        pb = nsr_model.infer(np.array([input_data]))
        # The first 5 entries are arg-maxed as the length prediction; the rest is
        # reshaped to 5 rows of 11 and arg-maxed per row as the digit predictions.
        print('actual: %s, length pb: %s, numbers: %s' % (
            label, np.argmax(pb[:5]), np.argmax(pb[5:].reshape([5, 11]), axis=1)))
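# NOTE(review): the two labels below ("评论列表" = comment list, "文章目录" =
# table of contents) appear to be leftover web-page navigation text, not code.
# 评论列表
# 文章目录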