def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,
                 nms_thresh=0.5, force_nms=True):
    """
    Wrapper to initialize a detector.

    Parameters:
    ----------
    net : str
        test network name; the module ``symbol_<net>`` is imported, with the
        ``symbol`` directory under the current working directory added to
        ``sys.path`` so it can be found
    prefix : str
        load model prefix; ``"_<data_shape>"`` is appended before it is
        handed to ``Detector``
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    nms_thresh : float
        threshold forwarded to ``get_symbol`` (presumably the NMS overlap
        threshold -- confirm against the symbol module)
    force_nms : bool
        force suppress different categories

    Returns:
    ----------
    Detector
        detector built from the loaded test symbol
    """
    symbol_dir = os.path.join(os.getcwd(), 'symbol')
    # Register the directory only once: appending on every call would make
    # sys.path grow without bound when several detectors are created.
    if symbol_dir not in sys.path:
        sys.path.append(symbol_dir)

    # Keep the loaded symbol in its own name instead of rebinding the
    # ``net`` string parameter.
    symbol = importlib.import_module("symbol_" + net).get_symbol(
        len(CLASSES), nms_thresh, force_nms)
    model_prefix = prefix + "_" + str(data_shape)
    return Detector(symbol, model_prefix, epoch, data_shape, mean_pixels,
                    ctx=ctx)
评论列表
文章目录