def get_symbol(num_classes=20, nms_thresh=0.5, force_suppress=False,
               nms_topk=400, **kwargs):
    """
    Build the detection (inference) symbol of the VGG16-based single-shot
    multi-box detector.

    This is the modified VGG 16 variant: fc6/fc7 are replaced by conv layers
    and the network is slightly smaller than the original VGG 16 network.
    The symbol is derived from the training symbol by picking up its
    internal multibox outputs and attaching softmax + detection on top.

    Parameters:
    ----------
    num_classes: int
        number of object classes not including background
    nms_thresh : float
        threshold of overlap for non-maximum suppression
    force_suppress : boolean
        whether suppress different class objects
    nms_topk : int
        apply NMS to top K detections
    **kwargs :
        accepted for call compatibility; not used by this function

    Returns:
    ----------
    mx.Symbol
    """
    # Start from the training graph and look up the three multibox heads
    # among its internal outputs.
    train_net = get_symbol_train(num_classes)
    internals = train_net.get_internals()
    class_scores = internals["multibox_cls_pred_output"]
    box_offsets = internals["multibox_loc_pred_output"]
    anchors = internals["multibox_anchors_output"]

    # Turn the raw class predictions into probabilities ('channel' mode).
    class_probs = mx.symbol.SoftmaxActivation(
        data=class_scores, mode='channel', name='cls_prob')

    # Combine probabilities, location predictions and anchors into the final
    # detection output, with NMS configured from the caller's arguments.
    return mx.contrib.symbol.MultiBoxDetection(
        class_probs, box_offsets, anchors,
        name="detection",
        nms_threshold=nms_thresh,
        force_suppress=force_suppress,
        variances=(0.1, 0.1, 0.2, 0.2),
        nms_topk=nms_topk)
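# (web page navigation residue, not part of the program: "Comment list")
# (web page navigation residue, not part of the program: "Article table of contents")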