def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data, is_region=False):
    """Compute the classification and box-regression losses for one batch.

    Args:
        rpn_cls_score_reshape: class scores; permuted with (0, 2, 3, 1) and
            flattened to rows of 2 scores, so a 4-D input whose flattened
            size is a multiple of 2 is required.
        rpn_bbox_pred: predicted box regression values.
        rpn_data: sequence of exactly four items
            ``(labels, bbox_targets, bbox_inside_weights, bbox_outside_weights)``.
            Entries labelled -1 are discarded; non-zero labels are counted
            as foreground. The outside weights are unpacked but not used.
        is_region: when true, the hit counters are stored on the
            ``*_region`` attributes instead of the plain ones.

    Returns:
        Tuple ``(rpn_cross_entropy, rpn_loss_box)``: the cross-entropy over
        the kept entries, and the summed smooth-L1 box loss divided by
        ``fg_cnt + 1e-4``.

    Side effects:
        Sets ``self.tp``, ``self.tf``, ``self.fg_cnt``, ``self.bg_cnt`` (or
        the ``*_region`` variants when ``is_region`` is true), and prints
        diagnostics when fewer than 256 entries are kept or ``DEBUG`` is set.

    NOTE(review): the kept indices are moved with ``.cuda()`` unconditionally,
    so the inputs are presumably CUDA tensors -- confirm against the caller.
    """
    # classification loss
    rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
    rpn_label = rpn_data[0]

    # Drop every entry whose label is -1 before computing any statistics.
    rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
    rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
    rpn_label = torch.index_select(rpn_label, 0, rpn_keep)

    fg_cnt = torch.sum(rpn_label.data.ne(0))
    bg_cnt = rpn_label.data.numel() - fg_cnt

    _, predict = torch.max(rpn_cls_score.data, 1)

    # Diagnostic output for batches with fewer than 256 kept entries.
    if predict.size()[0] < 256:
        print(predict.size())
        print(rpn_label.size())
        print(fg_cnt)

    # NOTE(review): slicing at fg_cnt assumes the kept labels are ordered
    # with all foreground entries first -- confirm against the code that
    # builds rpn_data.
    tp = torch.sum(predict[:fg_cnt].eq(rpn_label.data[:fg_cnt]))
    tf = torch.sum(predict[fg_cnt:].eq(rpn_label.data[fg_cnt:]))

    if is_region:
        self.tp_region = tp
        self.tf_region = tf
        self.fg_cnt_region = fg_cnt
        self.bg_cnt_region = bg_cnt
    else:
        self.tp = tp
        self.tf = tf
        self.fg_cnt = fg_cnt
        self.bg_cnt = bg_cnt

    if DEBUG:
        # Use the counters computed in this call; the region path must not
        # read self.tp / self.tf, which belong to the non-region path.
        print('accuracy: %2.2f%%' % ((tp + tf) / float(fg_cnt + bg_cnt) * 100))

    rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)

    # box loss
    rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
    rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
    rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)

    # Summed (not averaged) smooth-L1, normalised by the foreground count;
    # the 1e-4 guards against division by zero when fg_cnt is 0.
    rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)

    return rpn_cross_entropy, rpn_loss_box
评论列表
文章目录