encoder.py source code

python

Project: DSOD-Pytorch-Implementation  Author: Ellinier
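import torch

def decode(self, loc, conf):
        '''Transform predicted loc/conf back to real bbox locations and class labels.

        Args:
          loc: (tensor) predicted loc, sized [8732,4].
          conf: (tensor) predicted conf, sized [8732,21].

        Returns:
          boxes: (tensor) bbox locations, sized [#obj,4].
          labels: (tensor) class labels, sized [#obj,].
          scores: (tensor) class confidences, sized [#obj,].
          Returns (None, None, None) if every box is predicted as background.
        '''
        variances = self.variances
        # Undo the SSD offset encoding: sizes are log-encoded relative to the
        # default box size; centers are offsets scaled by the default box size.
        wh = torch.exp(loc[:,2:]*variances[1]) * self.default_boxes[:,2:]
        cxcy = loc[:,:2] * variances[0] * self.default_boxes[:,2:] + self.default_boxes[:,:2]
        boxes = torch.cat([cxcy-wh/2, cxcy+wh/2], 1)  # (cx,cy,w,h) -> (xmin,ymin,xmax,ymax), [8732,4]

        # max() drops the reduced dim by default, so both outputs are [8732,].
        max_conf, labels = conf.max(1)
        ids = labels.nonzero()  # label 0 is background; foreground indices, [#boxes,1]
        if ids.numel() == 0:
            return None, None, None
        ids = ids.squeeze(1)  # [#boxes,]

        # NMS runs once over all foreground boxes, i.e. it is class-agnostic.
        keep = self.nms(boxes[ids], max_conf[ids], threshold=0.3)
        # Subtract 1 so object labels start at 0 (background index removed).
        return boxes[ids][keep], labels[ids][keep]-1, max_conf[ids][keep]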
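The `wh`, `cxcy` and `boxes` lines of `decode` invert the SSD-style box encoding: the predicted center offsets are scaled by `variances[0]` and the default box size, while the predicted size offsets are scaled by `variances[1]` and exponentiated. The sketch below reproduces that arithmetic for a single box in plain Python so the formulas can be checked by hand. The variance values `(0.1, 0.2)` are an assumption (values commonly used with SSD; this file reads them from `self.variances`), and `encode_box`/`decode_box` are hypothetical helpers, not functions from this project.

```python
import math

# Hypothetical helpers illustrating the arithmetic in decode(); not part of encoder.py.
VARIANCES = (0.1, 0.2)  # assumed values; decode() reads them from self.variances


def encode_box(box, default_box, variances=VARIANCES):
    """Encode a (cx, cy, w, h) box as offsets relative to a (cx, cy, w, h) default box."""
    cx, cy, w, h = box
    dcx, dcy, dw, dh = default_box
    return (
        (cx - dcx) / (variances[0] * dw),
        (cy - dcy) / (variances[0] * dh),
        math.log(w / dw) / variances[1],
        math.log(h / dh) / variances[1],
    )


def decode_box(loc, default_box, variances=VARIANCES):
    """Mirror of decode() for one box: offsets -> (xmin, ymin, xmax, ymax)."""
    lx, ly, lw, lh = loc
    dcx, dcy, dw, dh = default_box
    # wh = exp(loc[2:] * variances[1]) * default_wh
    w = math.exp(lw * variances[1]) * dw
    h = math.exp(lh * variances[1]) * dh
    # cxcy = loc[:2] * variances[0] * default_wh + default_cxcy
    cx = lx * variances[0] * dw + dcx
    cy = ly * variances[0] * dh + dcy
    # Center/size -> corners, as in torch.cat([cxcy - wh/2, cxcy + wh/2], 1)
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


default_box = (0.5, 0.5, 0.2, 0.2)
loc = encode_box((0.55, 0.45, 0.3, 0.1), default_box)
print(decode_box(loc, default_box))
```

Encoding the box (cx=0.55, cy=0.45, w=0.3, h=0.1) and decoding it again recovers its corners (0.4, 0.4, 0.7, 0.5) up to floating-point rounding; an all-zero offset decodes to the default box itself.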
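`decode` relies on `self.nms`, which is defined elsewhere in the class and not shown here. Judging from the call, it takes the candidate boxes in corner form, their scores and an IoU `threshold`, and returns the indices of the boxes to keep. Below is a minimal greedy NMS with that behavior in plain Python; the names `iou` and `greedy_nms` and the list-based interface are assumptions for illustration, not the project's implementation, and the exact tie-breaking rule of `self.nms` is unknown. Because `decode` calls it once over all foreground boxes, suppression is class-agnostic: a high-scoring box of one class can suppress an overlapping box of another class.

```python
def iou(a, b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, threshold=0.3):
    """Return indices of kept boxes, highest score first (hypothetical stand-in for self.nms).

    A box is dropped when its IoU with an already-kept box exceeds `threshold`.
    """
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= threshold for k in keep):
            keep.append(i)
    return keep


boxes = [(0.0, 0.0, 1.0, 1.0), (0.1, 0.1, 1.0, 1.0), (2.0, 2.0, 3.0, 3.0)]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores))  # [0, 2]
```

The second box overlaps the first with an IoU of about 0.81, above the 0.3 threshold used in `decode`, so it is suppressed; the third box does not overlap either and is kept.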