MSDN_base.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:MSDN 作者: yikang-li 项目源码 文件源码
def build_loss_object(self, cls_score, bbox_pred, roi_data):
        """Compute the object classification and box-regression losses.

        Besides returning the losses, this records per-batch statistics on
        ``self``: ``tp`` (correct predictions among the first ``fg_cnt``
        entries), ``tf`` (correct predictions among the remaining entries),
        ``fg_cnt`` and ``bg_cnt``.

        Args:
            cls_score: Class scores; the predicted class is the argmax over
                dimension 1.
            bbox_pred: Predicted box regression values.
            roi_data: Sequence whose element 1 holds the class labels
                (label 0 is counted as background) and whose elements 2..4
                are ``bbox_targets``, ``bbox_inside_weights`` and
                ``bbox_outside_weights``.

        Returns:
            Tuple ``(cross_entropy, loss_box)``.

        NOTE(review): the tp/tf slicing assumes foreground ROIs come first
        in the batch -- confirm against the code that builds ``roi_data``.
        """
        # classification loss
        label = roi_data[1].squeeze()
        fg_cnt = torch.sum(label.data.ne(0))
        bg_cnt = label.data.numel() - fg_cnt

        # Re-weight the background class (index 0) by the fg/bg ratio; the
        # epsilon avoids a division by zero when there is no background.
        # presumably self.object_loss_weight is a torch tensor (the result
        # must support .cuda()) -- verify where it is initialised.
        ce_weights = np.sqrt(self.object_loss_weight)
        ce_weights[0] = float(fg_cnt) / (bg_cnt + 1e-5)
        ce_weights = ce_weights.cuda()

        _, predict = cls_score.data.max(1)
        if fg_cnt > 0:
            self.tp = torch.sum(predict[:fg_cnt].eq(label.data[:fg_cnt]))
        else:
            self.tp = 0.
        if bg_cnt > 0:
            self.tf = torch.sum(predict[fg_cnt:].eq(label.data[fg_cnt:]))
        else:
            # Fixed: this branch used to reset self.tp, which discarded the
            # foreground hit count and left self.tf unset or stale.
            self.tf = 0.
        self.fg_cnt = fg_cnt
        self.bg_cnt = bg_cnt

        cross_entropy = F.cross_entropy(cls_score, label, weight=ce_weights)

        # bounding box regression L1 loss
        # bbox_outside_weights is unpacked but not used by this loss.
        bbox_targets, bbox_inside_weights, bbox_outside_weights = roi_data[2:]

        # Apply the inside weights to both targets and predictions before
        # comparing them.
        bbox_targets = torch.mul(bbox_targets, bbox_inside_weights)
        bbox_pred = torch.mul(bbox_pred, bbox_inside_weights)

        # Summed smooth-L1 loss, normalised by the foreground count (the
        # epsilon guards against a batch with no foreground entries).
        loss_box = F.smooth_l1_loss(bbox_pred, bbox_targets, size_average=False) / (fg_cnt + 1e-5)

        return cross_entropy, loss_box
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号