multibox_loss.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:DSOD-Pytorch-Implementation 作者: Ellinier 项目源码 文件源码
def forward(self, loc_preds, loc_targets, conf_preds, conf_targets):
        '''Compute the localization and classification losses for one batch.

        The localization term is a summed SmoothL1 loss over the matched
        (positive) boxes. The classification term is a summed cross-entropy
        over the positive boxes plus the negatives picked by
        `self.hard_negative_mining`. Both terms are divided by the number of
        matched boxes in the batch.

        Args:
          loc_preds: (tensor) predicted locations, sized [batch_size, num_boxes, 4].
          loc_targets: (tensor) encoded target locations, sized [batch_size, num_boxes, 4].
          conf_preds: (tensor) predicted class confidences, sized [batch_size, num_boxes, num_classes].
          conf_targets: (tensor) encoded target classes, sized [batch_size, num_boxes];
            values > 0 mark boxes matched to an object.

        Returns:
          (loc_loss, conf_loss): tuple of two tensors. The two terms are NOT
          summed here; combining them is left to the caller. When no box in the
          batch is matched, both are zeros that stay attached to the
          predictions' graph (and device), so a later `backward()` still works.
        '''
        # Unpacking also enforces that loc_preds is 3-D.
        batch_size, num_boxes, _ = loc_preds.size()

        pos = conf_targets > 0  # [N,num_boxes], pos means the box matched.
        num_matched_boxes = pos.data.long().sum()
        if num_matched_boxes == 0:
            # Nothing to normalise by. Derive the zeros from the predictions
            # instead of creating fresh CPU tensors, so device, dtype and
            # autograd connectivity match the regular return path.
            return loc_preds.sum() * 0, conf_preds.sum() * 0

        ################################################################
        # loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
        ################################################################
        pos_mask = pos.unsqueeze(2).expand_as(loc_preds)    # [N,num_boxes,4]
        pos_loc_preds = loc_preds[pos_mask].view(-1, 4)      # [#pos,4]
        pos_loc_targets = loc_targets[pos_mask].view(-1, 4)  # [#pos,4]
        loc_loss = F.smooth_l1_loss(pos_loc_preds, pos_loc_targets, size_average=False)

        ################################################################
        # conf_loss = CrossEntropyLoss(pos_conf_preds, pos_conf_targets)
        #           + CrossEntropyLoss(neg_conf_preds, neg_conf_targets)
        ################################################################
        # Per-box loss, used only to rank candidates for negative mining.
        # NOTE(review): self.cross_entropy_loss is defined outside this block;
        # it is presumably an unreduced (per-box) loss -- confirm.
        per_box_loss = self.cross_entropy_loss(
            conf_preds.view(-1, self.num_classes), conf_targets.view(-1))
        # Row length follows the actual input rather than a fixed 8732, so
        # networks with a different number of default boxes work as well.
        per_box_loss = per_box_loss.view(-1, num_boxes)
        neg = self.hard_negative_mining(per_box_loss, pos)  # expected [N,num_boxes]

        # A box contributes to the classification loss if it is either a
        # positive or a mined negative; build that selection once and reuse it
        # for both predictions and targets.
        pos_and_neg = (pos + neg).gt(0)                            # [N,num_boxes]
        mask = pos_and_neg.unsqueeze(2).expand_as(conf_preds)      # [N,num_boxes,num_classes]
        preds = conf_preds[mask].view(-1, self.num_classes)        # [#pos+#neg,num_classes]
        targets = conf_targets[pos_and_neg]                        # [#pos+#neg,]
        conf_loss = F.cross_entropy(preds, targets, size_average=False)

        # Normalise both summed losses by the number of matched boxes.
        loc_loss = loc_loss / num_matched_boxes
        conf_loss = conf_loss / num_matched_boxes
        return loc_loss, conf_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号