multibox_loss.py source code

Project: DSOD-Pytorch-Implementation  Author: Ellinier
import torch


def hard_negative_mining(self, conf_loss, pos):
        '''Select hard negatives: 3x as many as the positive boxes (capped at num_boxes-1).

        Args:
          conf_loss: (tensor) cross entropy loss between conf_preds and conf_targets, sized [N*8732,].
          pos: (tensor) positive(matched) box indices, sized [N,8732].

        Returns:
          (tensor) bool mask of the selected negative boxes, sized [N,8732].
        '''
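        batch_size, num_boxes = pos.size()

        # Reshape to [N,8732] *before* masking so the shape of the boolean mask `pos`
        # matches; rank on a detached copy so the caller's loss tensor is not modified.
        conf_loss = conf_loss.detach().view(batch_size, -1).clone()  # [N,8732]
        conf_loss[pos] = 0  # zero positive boxes; the remaining entries are negative losses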

        _,idx = conf_loss.sort(1, descending=True)  # sort by neg conf_loss
        _,rank = idx.sort(1)  # [N,8732]

        num_pos = pos.long().sum(1, keepdim=True)  # [N,1]
        num_neg = torch.clamp(3*num_pos, max=num_boxes-1)  # [N,1]

        neg = rank < num_neg.view(-1, 1).expand_as(rank)  # [N,8732]
        return neg
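
To try the method outside its class, here is a minimal standalone sketch in PyTorch. It assumes, as the docstring states, that `conf_loss` is the per-box cross-entropy flattened to `[N*num_boxes]`. The `neg_pos_ratio` parameter, the toy tensors, and the downstream loss computation are hypothetical illustrations, not code from DSOD-Pytorch-Implementation.

```python
import torch
import torch.nn.functional as F


def hard_negative_mining(conf_loss, pos, neg_pos_ratio=3):
    """Return a [N, num_boxes] bool mask of the hardest negative boxes.

    Standalone sketch of the method above; ``neg_pos_ratio`` is a
    hypothetical parameter (the original hard-codes 3).
    """
    batch_size, num_boxes = pos.size()
    # Rank on a detached copy: picking indices needs no gradients,
    # and the caller's loss tensor is left untouched.
    loss = conf_loss.detach().view(batch_size, num_boxes).clone()
    loss[pos] = 0  # positives get the lowest loss, so they rank last
    _, idx = loss.sort(1, descending=True)  # idx[i, k]: box with the k-th largest loss
    _, rank = idx.sort(1)                   # rank[i, j]: position of box j in that order
    num_pos = pos.long().sum(1, keepdim=True)                          # [N, 1]
    num_neg = torch.clamp(neg_pos_ratio * num_pos, max=num_boxes - 1)  # [N, 1]
    return rank < num_neg  # [N, num_boxes] vs [N, 1] broadcasts per image


# Toy example: one image, six default boxes, box 1 is the only positive.
pos = torch.tensor([[False, True, False, False, False, False]])
conf_loss = torch.tensor([0.2, 2.0, 0.9, 0.1, 0.5, 0.7])  # flattened, [N*num_boxes]
neg = hard_negative_mining(conf_loss, pos)
print(neg)  # boxes 2, 5 and 4 (the three largest negative losses) are selected

# Hypothetical downstream use (an assumption; the caller is not shown here):
# only positives and mined negatives enter the classification loss.
torch.manual_seed(0)
num_classes = 4
conf_preds = torch.randn(1, 6, num_classes)        # [N, num_boxes, num_classes]
conf_targets = torch.tensor([[0, 2, 0, 0, 0, 0]])  # class 0 = background
ce = F.cross_entropy(conf_preds.view(-1, num_classes), conf_targets.view(-1),
                     reduction="none")             # [N*num_boxes]
mask = pos | hard_negative_mining(ce, pos)
cls_loss = F.cross_entropy(conf_preds[mask], conf_targets[mask], reduction="sum")
print(mask.sum().item(), cls_loss.item())
```

Sorting twice turns the descending-order indices into per-box ranks, so `rank < num_neg` keeps exactly the `num_neg` highest-loss boxes in each image. The clamp only caps the count at `num_boxes - 1`. If `3 * num_pos` exceeds the number of actual negatives, the mask can therefore also include positive boxes. That is harmless when the caller ORs the mask with `pos`, as in the sketch.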