RPN.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:MSDN 作者: yikang-li 项目源码 文件源码
def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data, is_region=False):
        """Compute the RPN classification and box-regression losses.

        Side effect: records accuracy statistics on ``self``. When
        ``is_region`` is False it writes ``tp``, ``tf``, ``fg_cnt`` and
        ``bg_cnt``; when True it writes ``tp_region``, ``tf_region``,
        ``fg_cnt_region`` and ``bg_cnt_region``. ``tp`` counts correctly
        classified foreground anchors and ``tf`` counts correctly classified
        background anchors.

        Args:
            rpn_cls_score_reshape: classification scores. They are permuted
                with ``permute(0, 2, 3, 1)`` and flattened to ``(-1, 2)``, so
                presumably (N, 2, H, W) -- TODO confirm against the caller.
            rpn_bbox_pred: predicted box deltas. They are multiplied
                element-wise by ``rpn_bbox_inside_weights``.
            rpn_data: 4-element sequence ``(labels, bbox_targets,
                bbox_inside_weights, bbox_outside_weights)``. Labels equal to
                -1 are ignored, 0 counts as background and any other value as
                foreground.
            is_region: selects which set of statistics attributes is written.

        Returns:
            ``(rpn_cross_entropy, rpn_loss_box)``. The first is
            ``F.cross_entropy`` over the non-ignored anchors. The second is
            the summed smooth-L1 box loss divided by the foreground count
            (``+ 1e-4`` so zero foreground anchors does not divide by zero).
        """
        # ---- classification loss ----
        rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_label = rpn_data[0]

        # Keep only anchors whose label is not -1 (ignored).
        # view(-1) instead of squeeze(): with exactly one kept anchor,
        # squeeze() would produce a 0-dim tensor, which index_select rejects.
        rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().view(-1)).cuda()
        rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
        rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

        label_data = rpn_label.data
        fg_mask = label_data.ne(0)
        fg_cnt = torch.sum(fg_mask)
        bg_cnt = label_data.numel() - fg_cnt

        _, predict = torch.max(rpn_cls_score.data, 1)

        # Diagnostic kept from the original: report when fewer than 256
        # anchors survive the -1 filter (256 is presumably the anchor
        # sampling batch size -- TODO confirm).
        if predict.size()[0] < 256:
            print(predict.size())
            print(rpn_label.size())
            print(fg_cnt)

        # Count correct predictions with masks rather than slicing
        # predict[:fg_cnt]. The kept labels are in anchor-index order
        # (rpn_keep comes from nonzero()), so foreground entries are not
        # guaranteed to come first and the slices would mix fg and bg.
        correct = predict.eq(label_data)
        tp = torch.sum(correct * fg_mask)  # element-wise AND of the two masks
        tf = torch.sum(correct) - tp

        if is_region:
            self.tp_region = tp
            self.tf_region = tf
            self.fg_cnt_region = fg_cnt
            self.bg_cnt_region = bg_cnt
        else:
            self.tp = tp
            self.tf = tf
            self.fg_cnt = fg_cnt
            self.bg_cnt = bg_cnt

        if DEBUG:
            # Use this call's counts. The original region branch read the
            # non-region self.tp / self.tf here by mistake.
            print('accuracy: %2.2f%%' % ((tp + tf) / float(fg_cnt + bg_cnt) * 100))

        rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)

        # ---- box regression loss ----
        # The 3-way unpack also enforces that rpn_data has exactly 4 entries.
        # NOTE(review): rpn_bbox_outside_weights is unpacked but never applied.
        rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
        # Masking both prediction and target by the inside weights makes
        # entries with zero weight contribute zero to the smooth-L1 sum.
        rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
        rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

        # Summed (not averaged) loss, normalized by the foreground count.
        rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)

        return rpn_cross_entropy, rpn_loss_box
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号