def get_runner(path):
    """Build single-image and batch inference callables from saved parameters.

    Loads a ``.npy`` file holding a pickled parameter dict, restores it into
    ``Model()`` via ``ParamRestore``, and wraps the result in an
    ``OfflinePredictor`` that feeds the ``'input'`` tensor and fetches the
    ``'resized_map'`` tensor.

    Args:
        path: Path to a ``.npy`` file produced by ``np.save`` on a dict
            (stored as a 0-d object array, which ``.item()`` unwraps).

    Returns:
        A ``(func_single, func_batch)`` tuple of callables; see their
        docstrings.

    Note:
        The parameter file is unpickled (``allow_pickle=True``), which can
        execute arbitrary code. Only load parameter files from trusted sources.
    """
    # A dict saved with np.save is a pickled 0-d object array; NumPy >= 1.16.3
    # refuses to load it unless allow_pickle=True is given explicitly.
    # encoding='latin1' lets Python 3 read pickles written by Python 2.
    param_dict = np.load(path, encoding='latin1', allow_pickle=True).item()
    predict_func = OfflinePredictor(PredictConfig(
        model=Model(),
        # NOTE(review): ParamRestore was renamed DictRestore in later
        # tensorpack releases -- confirm against the pinned version.
        session_init=ParamRestore(param_dict),
        # presumably the GPU memory fraction -- confirm against tensorpack.
        session_config=get_default_sess_config(0.99),
        input_names=['input'],
        output_names=['resized_map'],
    ))

    def func_single(img):
        """Run the predictor on one image.

        Per the original author's notes: ``img`` is BGR with values in
        [0, 255], and the result is a WxHx15 map.
        """
        # Wrap img as a batch of one, then take the first fetched output
        # and its first (only) batch element.
        return predict_func([[img]])[0][0]

    def func_batch(imgs):
        """Run the predictor on a batch of images.

        Per the original author's notes: ``imgs`` is BGR with values in
        [0, 255], laid out NHWC, and the result is NHWC as well.
        """
        # Return the first fetched output for the whole batch.
        return predict_func([imgs])[0]

    return func_single, func_batch
# Comment list (scraped web-page navigation label)
# Table of contents (scraped web-page navigation label)