def build_loss_bbox(self, bbox_pred, roi_data):
    """Compute the foreground-normalised smooth-L1 box-regression loss.

    Predictions and targets are both scaled element-wise by the inside
    weights, so entries with a zero weight contribute nothing. The summed
    smooth-L1 loss is then divided by the number of rows whose first
    inside-weight column is non-zero (plus a small epsilon).

    Args:
        bbox_pred: Predicted box-regression tensor; must be broadcast-compatible
            with ``roi_data[2]`` (targets) and ``roi_data[3]`` (inside weights).
        roi_data: Sequence whose items from index 2 onward are exactly
            ``(bbox_targets, bbox_inside_weights, bbox_outside_weights)``.
            The outside weights are unpacked but not used by this loss.

    Returns:
        Scalar tensor: summed smooth-L1 loss / (foreground count + 1e-5).

    Raises:
        ValueError: if ``roi_data[2:]`` does not contain exactly three items.
    """
    # Unpacking into three names preserves the original length check on
    # roi_data; the outside weights are intentionally ignored here.
    bbox_targets, bbox_inside_weights, _ = roi_data[2:]

    # Scale predictions and targets by the inside weights so that
    # zero-weight entries contribute nothing to the loss.
    bbox_targets = torch.mul(bbox_targets, bbox_inside_weights)
    bbox_pred = torch.mul(bbox_pred, bbox_inside_weights)

    # Count rows whose first inside-weight column is non-zero. The comparison
    # yields a bool tensor without autograd history, so `.data` is not needed.
    # NOTE(review): assumes bbox_inside_weights is 2-D and a foreground row has
    # a non-zero first column (class-agnostic (N, 4) layout). If the weights
    # are class-specific (N, 4 * num_classes), column 0 may always be zero and
    # the loss would be scaled by ~1e5 -- confirm against the caller.
    fg_cnt = torch.sum(bbox_inside_weights[:, 0].ne(0))

    # reduction='sum' is the non-deprecated spelling of size_average=False.
    loss_box = F.smooth_l1_loss(bbox_pred, bbox_targets, reduction='sum')
    # The epsilon avoids division by zero when no row is foreground.
    return loss_box / (fg_cnt + 1e-5)
评论列表
文章目录