def decode(self, loc, conf):
    """Transform predicted loc/conf back to real bbox locations and class labels.

    Offsets in ``loc`` are decoded against ``self.default_boxes`` (columns
    0-1 are used as box centers, columns 2-3 as box sizes) and scaled by
    ``self.variances`` (index 0 for the center offsets, index 1 for the
    log-size offsets). Each anchor is assigned its highest-scoring class;
    anchors whose best class is index 0 are discarded, and the rest are
    filtered through ``self.nms`` with an overlap threshold of 0.3.

    Args:
      loc: (tensor) predicted loc, sized [#anchors, 4], e.g. [8732, 4].
      conf: (tensor) predicted conf, sized [#anchors, #classes],
        e.g. [8732, 21].

    Returns:
      boxes: (tensor) bbox locations as (xmin, ymin, xmax, ymax),
        sized [#obj, 4].
      labels: (tensor) class labels shifted down by one (class index 1
        becomes label 0), sized [#obj, 1].
      scores: (tensor) confidence of the chosen class, sized [#obj, 1].

      All three are ``None`` when no anchor has a best class other than
      index 0.
    """
    variances = self.variances
    default_cxcy = self.default_boxes[:, :2]
    default_wh = self.default_boxes[:, 2:]

    # Undo the offset encoding: sizes were log-scaled, centers were
    # expressed relative to the default box size.
    wh = torch.exp(loc[:, 2:] * variances[1]) * default_wh
    cxcy = loc[:, :2] * variances[0] * default_wh + default_cxcy
    boxes = torch.cat([cxcy - wh / 2, cxcy + wh / 2], 1)  # [#anchors, 4]

    # keepdim=True is required: reductions drop the reduced dimension by
    # default, which would make the squeeze(1) calls below fail and change
    # the [#obj, 1] shape of the returned labels/scores.
    max_conf, labels = conf.max(1, keepdim=True)  # each [#anchors, 1]

    # Indices of anchors whose best class is not index 0.
    ids = labels.squeeze(1).nonzero()
    if ids.numel() == 0:
        return None, None, None
    ids = ids.squeeze(1)  # [#boxes,]

    candidate_boxes = boxes[ids]
    candidate_conf = max_conf[ids]
    # NOTE(review): self.nms is assumed to return indices into the
    # candidate tensors -- confirm against its definition.
    keep = self.nms(candidate_boxes, candidate_conf.squeeze(1), threshold=0.3)
    return candidate_boxes[keep], labels[ids][keep] - 1, candidate_conf[keep]
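# 评论列表 (comment list) -- page-navigation residue from the source web page, not code
# 文章目录 (table of contents) -- page-navigation residue from the source web page, not code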