def s1_predict(config_file, model_dir, model_file, predict_file_list, out_dir):
    """Run the 's1_keras' converted model on images and save annotated copies.

    This function serves as a test/validation tool during model development.
    It is not used as a final product in any part of the pipeline.

    Args:
        config_file: Path to a JSON configuration file; the parsed contents
            are passed to ``ConvertedModel``.
        model_dir: Model directory, forwarded to ``ConvertedModel``.
        model_file: Model file name, forwarded to ``ConvertedModel``.
        predict_file_list: Iterable of image file paths to run prediction on.
        out_dir: Directory where the annotated images are written, each one
            under the same base filename as its input. Created if missing.

    Raises:
        IOError: If an input image cannot be read, or an annotated image
            cannot be written.
    """
    with open(config_file) as config_buffer:
        config = json.load(config_buffer)

    # cv2.imwrite does not raise when the target directory is missing (it
    # just returns False), so make sure the directory exists up front.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    with tf.Graph().as_default() as graph:
        converted_model = ConvertedModel(config, graph, 's1_keras', model_dir, model_file)
        with tf.Session(graph=graph) as sess:
            for img_file in predict_file_list:
                # cv2.imread signals failure by returning None rather than
                # raising; fail loudly here instead of deep inside predict().
                image = cv2.imread(img_file)
                if image is None:
                    raise IOError('Could not read image: {}'.format(img_file))

                boxes = converted_model.predict(sess, image)
                image = draw_boxes(image, boxes)

                out_path = os.path.join(out_dir, os.path.basename(img_file))
                # cv2.imwrite reports failure via its boolean return value.
                if not cv2.imwrite(out_path, image):
                    raise IOError('Could not write image: {}'.format(out_path))
评论列表
文章目录