def __init__(self, symbol, model_prefix, epoch, data_hw, mean_pixels,
             img_stride=32, th_nms=0.3333, ctx=None):
    '''
    Load checkpoint weights and bind an inference module for a fixed input size.

    Parameters:
    ----------
    symbol : mx.Symbol
        network to bind; the symbol stored in the checkpoint is ignored
    model_prefix : str
        checkpoint prefix passed to mx.model.load_checkpoint
    epoch : int
        checkpoint epoch passed to mx.model.load_checkpoint
    data_hw : int or (int, int)
        input (height, width); a single int means a square input
    mean_pixels :
        stored as self.mean_pixels (not used in this method)
    img_stride : int
        both input dimensions must be multiples of this value
    th_nms : float
        stored as self.th_nms (not used in this method)
    ctx : mx.Context or None
        device to run on; falls back to mx.cpu() when not given
    '''
    self.ctx = mx.cpu() if not ctx else ctx
    if isinstance(data_hw, int):
        data_hw = (data_hw, data_hw)
    assert data_hw[0] % img_stride == 0 and data_hw[1] % img_stride == 0, \
        'data_hw {} must be a multiple of img_stride {}'.format(data_hw, img_stride)
    self.data_hw = data_hw
    # Only the weights are taken from the checkpoint; its symbol is discarded.
    _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
    # Use the resolved self.ctx: the raw ``ctx`` argument may be None, and
    # Module expects a Context (or list of Contexts), not None.
    self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
    # Bound for a single image with 3 channels at the requested size.
    self.mod.bind(data_shapes=[('data', (1, 3, data_hw[0], data_hw[1]))])
    self.mod.set_params(arg_params, aux_params)
    self.mean_pixels = mean_pixels
    self.img_stride = img_stride
    self.th_nms = th_nms
# Python ctx() usage examples (source code)
def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels,
             batch_size=1, ctx=None):
    """
    Load a checkpoint and bind an inference module for square inputs.

    Parameters:
    ----------
    symbol : mx.Symbol or None
        network to bind; if None, the symbol stored in the checkpoint is used
    model_prefix : str
        checkpoint prefix passed to mx.model.load_checkpoint
    epoch : int
        checkpoint epoch passed to mx.model.load_checkpoint
    data_shape : int
        input height and width; the module is bound to
        (batch_size, 3, data_shape, data_shape)
    mean_pixels :
        stored as self.mean_pixels (not used in this method)
    batch_size : int
        batch dimension of the bound input
    ctx : mx.Context or None
        device to run on; defaults to mx.cpu()
    """
    self.ctx = ctx
    if self.ctx is None:
        self.ctx = mx.cpu()
    load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    if symbol is None:
        symbol = load_symbol
    # Bind on the resolved self.ctx, not the raw ``ctx`` argument, so the
    # mx.cpu() default actually reaches the Module when ctx is None.
    self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
    self.data_shape = data_shape
    self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))])
    self.mod.set_params(args, auxs)
    self.mean_pixels = mean_pixels
def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels,
             batch_size=1, ctx=None):
    """
    Load a checkpoint and bind an inference module for square inputs.

    Parameters:
    ----------
    symbol : mx.Symbol or None
        network to bind; if None, the symbol stored in the checkpoint is used
    model_prefix : str
        checkpoint prefix passed to mx.model.load_checkpoint
    epoch : int
        checkpoint epoch passed to mx.model.load_checkpoint
    data_shape : int
        input height and width; the module is bound to
        (batch_size, 3, data_shape, data_shape)
    mean_pixels :
        stored as self.mean_pixels (not used in this method)
    batch_size : int
        batch dimension of the bound input
    ctx : mx.Context or None
        device to run on; defaults to mx.cpu()
    """
    self.ctx = ctx
    if self.ctx is None:
        self.ctx = mx.cpu()
    load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    if symbol is None:
        symbol = load_symbol
    # Bind on the resolved self.ctx, not the raw ``ctx`` argument, so the
    # mx.cpu() default actually reaches the Module when ctx is None.
    self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
    self.data_shape = data_shape
    self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))])
    self.mod.set_params(args, auxs)
    self.mean_pixels = mean_pixels
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class,
                 nms_thresh=0.5, force_nms=True, nms_topk=400):
    """
    Build (optionally) a test network and wrap it in a Detector.

    Parameters:
    ----------
    net : str or None
        test network name; when None, no symbol is built here and None is
        handed to Detector
    prefix : str
        load model prefix
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    num_class : int
        number of classes
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    nms_topk : int
        forwarded to get_symbol as ``nms_topk``

    Returns:
    ----------
    Detector
    """
    symbol = net
    if symbol is not None:
        symbol = get_symbol(symbol, data_shape,
                            num_classes=num_class,
                            nms_thresh=nms_thresh,
                            force_nms=force_nms,
                            nms_topk=nms_topk)
    return Detector(symbol, prefix, epoch, data_shape, mean_pixels, ctx=ctx)
def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels,
             batch_size=1, ctx=None):
    """
    Load a checkpoint and bind a module for square inputs.

    Parameters:
    ----------
    symbol : mx.Symbol or None
        network to bind; if None, the symbol stored in the checkpoint is used
        (previously a None symbol was passed straight to Module)
    model_prefix : str
        checkpoint prefix passed to mx.model.load_checkpoint
    epoch : int
        checkpoint epoch passed to mx.model.load_checkpoint
    data_shape : int
        input height and width; the module is bound to
        (batch_size, 3, data_shape, data_shape)
    mean_pixels :
        stored as self.mean_pixels (not used in this method)
    batch_size : int
        batch dimension of the bound input
    ctx : mx.Context or None
        device to run on; defaults to mx.cpu()
    """
    self.ctx = ctx
    if self.ctx is None:
        self.ctx = mx.cpu()
    load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    if symbol is None:
        symbol = load_symbol
    # label_names is left at Module's default here (unlike sibling variants
    # that pass label_names=None). Bind on self.ctx so the CPU default is
    # honoured when ctx is None.
    self.mod = mx.mod.Module(symbol, context=self.ctx)
    self.data_shape = data_shape
    self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))])
    self.mod.set_params(args, auxs)
    self.mean_pixels = mean_pixels
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,
                 nms_thresh=0.5, force_nms=True):
    """
    wrapper for initialize a detector

    Imports ``symbol_<net>`` from the ``symbol`` directory under the current
    working directory and builds the network with len(CLASSES) classes.

    Parameters:
    ----------
    net : str
        test network name
    prefix : str
        load model prefix; "_<data_shape>" is appended before loading
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    nms_thresh : float
        non-maximum suppression threshold, forwarded to get_symbol
    force_nms : bool
        force suppress different categories
    """
    # Make the local 'symbol' package directory importable without adding a
    # duplicate sys.path entry on every call.
    symbol_dir = os.path.join(os.getcwd(), 'symbol')
    if symbol_dir not in sys.path:
        sys.path.append(symbol_dir)
    net = importlib.import_module("symbol_" + net) \
        .get_symbol(len(CLASSES), nms_thresh, force_nms)
    detector = Detector(net, prefix + "_" + str(data_shape), epoch,
                        data_shape, mean_pixels, ctx=ctx)
    return detector
def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels,
             batch_size=1, ctx=None):
    """
    Load a checkpoint and bind an inference module for square inputs.

    Parameters:
    ----------
    symbol : mx.Symbol or None
        network to bind; if None, the symbol stored in the checkpoint is used
    model_prefix : str
        checkpoint prefix passed to mx.model.load_checkpoint
    epoch : int
        checkpoint epoch passed to mx.model.load_checkpoint
    data_shape : int
        input height and width; the module is bound to
        (batch_size, 3, data_shape, data_shape)
    mean_pixels :
        stored as self.mean_pixels (not used in this method)
    batch_size : int
        batch dimension of the bound input
    ctx : mx.Context or None
        device to run on; defaults to mx.cpu()
    """
    self.ctx = ctx
    if self.ctx is None:
        self.ctx = mx.cpu()
    load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    if symbol is None:
        symbol = load_symbol
    # Bind on the resolved self.ctx, not the raw ``ctx`` argument, so the
    # mx.cpu() default actually reaches the Module when ctx is None.
    self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
    self.data_shape = data_shape
    self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))])
    self.mod.set_params(args, auxs)
    self.mean_pixels = mean_pixels
    # NMS threshold comes from the module-level ``cfg.valid`` mapping.
    self.th_nms = cfg.valid['th_nms']
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class,
                 nms_thresh=0.5, force_nms=True, nms_topk=400):
    """
    Build (optionally) a test network and wrap it in a FaceDetector.

    Parameters:
    ----------
    net : str or None
        test network name; when None, no symbol is built here and None is
        handed to FaceDetector
    prefix : str
        load model prefix
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    num_class : int
        number of classes
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    nms_topk : int
        forwarded to get_symbol as ``nms_topk``

    Returns:
    ----------
    FaceDetector
    """
    symbol = net
    if symbol is not None:
        symbol = get_symbol(symbol, data_shape,
                            num_classes=num_class,
                            nms_thresh=nms_thresh,
                            force_nms=force_nms,
                            nms_topk=nms_topk)
    return FaceDetector(symbol, prefix, epoch, data_shape, mean_pixels, ctx=ctx)
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class,
                 nms_thresh=0.5, force_nms=True, nms_topk=400):
    """
    Build (optionally) a test network, run estimate_mac on it, and wrap it
    in a Detector.

    Parameters:
    ----------
    net : str or None
        test network name; when None, no symbol is built here and None is
        handed to Detector
    prefix : str
        load model prefix
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    num_class : int
        number of classes
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    nms_topk : int
        forwarded to get_symbol as ``nms_topk``

    Returns:
    ----------
    Detector
    """
    if net is not None:
        single_image_shape = (1, 3, data_shape, data_shape)
        net = get_symbol(net, data_shape,
                         num_classes=num_class,
                         nms_thresh=nms_thresh,
                         force_nms=force_nms,
                         nms_topk=nms_topk)
        # Both returned values are discarded; presumably estimate_mac is
        # called for a reporting side effect — verify against its definition.
        _, _ = estimate_mac(net, data_shape=single_image_shape)
    return Detector(net, prefix, epoch, data_shape, mean_pixels, ctx=ctx)
def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels,
             batch_size=1, ctx=None):
    """
    Load a checkpoint and bind a YOLO module for square inputs.

    Parameters:
    ----------
    symbol : mx.Symbol or None
        network to bind; if None, the symbol stored in the checkpoint is used
    model_prefix : str
        checkpoint prefix passed to mx.model.load_checkpoint
    epoch : int
        checkpoint epoch passed to mx.model.load_checkpoint
    data_shape : int
        input height and width; data is bound to
        (batch_size, 3, data_shape, data_shape)
    mean_pixels :
        stored as self.mean_pixels (not used in this method)
    batch_size : int
        batch dimension of the bound data and label
    ctx : mx.Context or None
        device to run on; defaults to mx.cpu()
    """
    self.ctx = ctx
    if self.ctx is None:
        self.ctx = mx.cpu()
    load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    if symbol is None:
        symbol = load_symbol
    # The module declares a 'yolo_output_label' input and binds it with shape
    # (batch_size, 2, 5); NOTE(review): meaning of the 2 and 5 dims not shown
    # here — confirm against the symbol definition. self.ctx is used so the
    # CPU default is honoured when ctx is None.
    self.mod = mx.mod.Module(symbol, label_names=("yolo_output_label",),
                             context=self.ctx)
    self.data_shape = data_shape
    self.mod.bind(data_shapes=[('data', (batch_size, 3, data_shape, data_shape))],
                  label_shapes=[('yolo_output_label', (batch_size, 2, 5))])
    self.mod.set_params(args, auxs)
    self.mean_pixels = mean_pixels
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,
                 nms_thresh=0.5, force_nms=True):
    """
    wrapper for initialize a detector

    When ``net`` is given, imports ``symbol_<net>`` from the ``symbol``
    directory under the current working directory, builds it with
    len(CLASSES) classes, and extends ``prefix`` to
    "<prefix>_<net without '_yolo' suffix>_416".

    Parameters:
    ----------
    net : str or None
        test network name; when None, the prefix is used unchanged and None
        is handed to Detector
    prefix : str
        load model prefix
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    nms_thresh : float
        non-maximum suppression threshold, forwarded to get_symbol
    force_nms : bool
        force suppress different categories
    """
    # Make the local 'symbol' package directory importable without adding a
    # duplicate sys.path entry on every call.
    symbol_dir = os.path.join(os.getcwd(), 'symbol')
    if symbol_dir not in sys.path:
        sys.path.append(symbol_dir)
    if net is not None:
        # Remove only a trailing '_yolo'. str.strip('_yolo') would strip any
        # of the characters '_', 'y', 'o', 'l' from both ends
        # (e.g. 'tiny_yolo' -> 'tin').
        suffix = '_yolo'
        base_name = net[:-len(suffix)] if net.endswith(suffix) else net
        # NOTE(review): the checkpoint suffix is fixed at 416 regardless of
        # data_shape — confirm this is intended.
        prefix = prefix + "_" + base_name + '_' + str(416)
        net = importlib.import_module("symbol_" + net) \
            .get_symbol(len(CLASSES), nms_thresh, force_nms)
    detector = Detector(net, prefix, epoch,
                        data_shape, mean_pixels, ctx=ctx)
    return detector