def __init__(self, symbol, model_prefix, epoch, data_hw, mean_pixels,
img_stride=32, th_nms=0.3333, ctx=None):
'''
'''
self.ctx = mx.cpu() if not ctx else ctx
if isinstance(data_hw, int):
data_hw = (data_hw, data_hw)
assert data_hw[0] % img_stride == 0 and data_hw[1] % img_stride == 0
self.data_hw = data_hw
_, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
self.mod = mx.mod.Module(symbol, label_names=None, context=ctx)
self.mod.bind(data_shapes=[('data', (1, 3, data_hw[0], data_hw[1]))])
self.mod.set_params(arg_params, aux_params)
self.mean_pixels = mean_pixels
self.img_stride = img_stride
self.th_nms = th_nms
评论列表
文章目录