def build_loss_objectiveness(self, region_objectiveness, targets):
    """Compute the objectness classification loss and record accuracy stats.

    Args:
        region_objectiveness: Classification scores passed to
            ``F.cross_entropy`` as its input; the argmax over dim 1 is used
            as the predicted label for the bookkeeping below.
        targets: Ground-truth class indices. Label 0 is counted as
            background, every non-zero label as foreground.

    Returns:
        The cross-entropy loss between ``region_objectiveness`` and
        ``targets``.

    Side effects (logging bookkeeping, stored as plain Python ints):
        self.tp_reg: foreground entries whose prediction equals the label.
        self.tf_reg: background entries whose prediction equals the label.
        self.fg_cnt_reg: number of foreground entries.
        self.bg_cnt_reg: number of background entries.
    """
    loss_objectiveness = F.cross_entropy(region_objectiveness, targets)

    # Statistics only: operate on .data so none of this enters autograd.
    _, predict = region_objectiveness.data.max(1)
    labels = targets.squeeze().data
    fg_mask = labels.ne(0)
    bg_mask = labels.eq(0)
    correct = predict.eq(labels)

    # Count with boolean masks rather than slicing [:fg_cnt] / [fg_cnt:],
    # which silently required all foreground targets to precede all
    # background ones (and failed on a 0-dim label after squeeze()).
    # Cast to long before summing so narrow mask dtypes (uint8 on older
    # PyTorch) cannot overflow.
    fg_cnt = int(fg_mask.long().sum())
    bg_cnt = labels.numel() - fg_cnt

    # Both counters are always assigned, so neither can go stale or be
    # clobbered when one of the classes is absent from this batch.
    self.tp_reg = int((correct & fg_mask).long().sum())
    self.tf_reg = int((correct & bg_mask).long().sum())
    self.fg_cnt_reg = fg_cnt
    self.bg_cnt_reg = bg_cnt
    return loss_objectiveness
评论列表
文章目录