def build_loss(self, cls_score, bbox_pred, roi_data):
    """Compute the RoI classification loss and the box-regression loss.

    Args:
        cls_score: per-RoI class scores fed to ``F.cross_entropy``
            (presumably shape (num_rois, num_classes) -- TODO confirm).
        bbox_pred: predicted box regression values; multiplied elementwise
            by the inside weights before the loss is taken.
        roi_data: sequence of exactly five items. Index 1 holds the class
            labels (0 = background, any non-zero value = foreground).
            Indices 2-4 hold the box targets, the inside weights and the
            outside weights. Index 0 is not read here.

    Returns:
        ``(cross_entropy, loss_box)``:
        - the class-weighted cross-entropy, where background is
          down-weighted by the fg/bg ratio;
        - the summed smooth-L1 box loss divided by the foreground count.

    When ``self.debug`` is truthy, also stores ``tp``, ``tf``, ``fg_cnt``
    and ``bg_cnt`` on ``self`` for logging.
    """
    # classification loss
    # Flatten to 1-D: ``squeeze()`` would turn a single-RoI batch into a
    # 0-d tensor, which ``F.cross_entropy`` rejects as a target.
    label = roi_data[1].view(-1)
    fg_cnt = torch.sum(label.data.ne(0))
    bg_cnt = label.data.numel() - fg_cnt

    # for log
    if self.debug:
        _, predict = cls_score.data.max(1)
        # NOTE: the slicing below assumes foreground RoIs come first in the
        # batch, followed by background RoIs.
        self.tp = torch.sum(predict[:fg_cnt].eq(label.data[:fg_cnt])) if fg_cnt > 0 else 0
        self.tf = torch.sum(predict[fg_cnt:].eq(label.data[fg_cnt:]))
        self.fg_cnt = fg_cnt
        self.bg_cnt = bg_cnt

    # Build the class weights on the same device and with the same dtype as
    # the scores, rather than forcing CUDA.
    ce_weights = torch.ones(cls_score.size(1), dtype=cls_score.dtype, device=cls_score.device)
    if fg_cnt > 0 and bg_cnt > 0:
        # Down-weight background (class 0) by the fg/bg ratio. Without
        # foreground RoIs this ratio is 0, and the weighted mean would become
        # 0/0 = NaN, so the weight is left at 1 in that case.
        ce_weights[0] = float(fg_cnt) / float(bg_cnt)
    cross_entropy = F.cross_entropy(cls_score, label, weight=ce_weights)

    # bounding box regression L1 loss
    # The outside weights are unpacked (which enforces len(roi_data) == 5)
    # but are not used by this loss.
    bbox_targets, bbox_inside_weights, bbox_outside_weights = roi_data[2:]
    bbox_targets = torch.mul(bbox_targets, bbox_inside_weights)
    bbox_pred = torch.mul(bbox_pred, bbox_inside_weights)
    # The small epsilon keeps the normalisation finite when fg_cnt == 0.
    loss_box = F.smooth_l1_loss(bbox_pred, bbox_targets, reduction='sum') / (fg_cnt + 1e-4)
    return cross_entropy, loss_box
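# 评论列表 ("Comment list") -- likely leftover web-page navigation text
# 文章目录 ("Table of contents") -- likely leftover web-page navigation text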