import numpy as np
import torch
import torch.nn.functional as F

def build_loss_cls(self, cls_score, labels):
    labels = labels.squeeze()
    # foreground = any non-zero predicate label, background = label 0
    fg_cnt = int(torch.sum(labels.data.ne(0)))
    bg_cnt = labels.data.numel() - fg_cnt
    # per-class weights: sqrt of the predicate loss weights, with the background
    # weight (index 0) replaced by this batch's fg/bg ratio
    # (assumes self.predicate_loss_weight is a numpy array, as the np.sqrt call implies)
    ce_weights = torch.from_numpy(np.sqrt(self.predicate_loss_weight)).float()
    ce_weights[0] = float(fg_cnt) / (bg_cnt + 1e-5)
    ce_weights = ce_weights.cuda()
    cross_entropy = F.cross_entropy(cls_score, labels, weight=ce_weights)

    # accuracy bookkeeping: the slicing below assumes the sampled pairs are ordered
    # with the bg_cnt background pairs first and the fg_cnt foreground pairs last
    maxv, predict = cls_score.data.max(1)
    if fg_cnt == 0:
        tp = 0
    else:
        tp = torch.sum(predict[bg_cnt:].eq(labels.data[bg_cnt:]))
    tf = torch.sum(predict[:bg_cnt].eq(labels.data[:bg_cnt]))
    return cross_entropy, tp, tf, fg_cnt, bg_cnt
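
The function is shown here out of its class context, so below is a minimal sketch of how it might be called. The `PredicateHead` wrapper, the uniform `predicate_loss_weight`, and the class/pair counts are placeholders invented for illustration (in the real model the weights are presumably derived from predicate statistics); the example also assumes a CUDA device is available, since the weights are moved to the GPU unconditionally.

```python
import numpy as np
import torch

class PredicateHead:
    """Minimal stand-in for the network class that owns build_loss_cls."""
    def __init__(self, num_predicates):
        # hypothetical uniform weights; the real model loads precomputed ones
        self.predicate_loss_weight = np.ones(num_predicates, dtype=np.float32)

    build_loss_cls = build_loss_cls  # reuse the module-level function above as a method

num_predicates = 51   # e.g. background + 50 predicate classes
head = PredicateHead(num_predicates)

cls_score = torch.randn(32, num_predicates).cuda()
# background pairs (label 0) first, foreground pairs after, matching the
# ordering the accuracy bookkeeping above relies on
labels = torch.cat([torch.zeros(24, dtype=torch.long),
                    torch.randint(1, num_predicates, (8,))]).cuda()

loss, tp, tf, fg_cnt, bg_cnt = head.build_loss_cls(cls_score, labels)
print(loss.item(), int(tp), int(tf), fg_cnt, bg_cnt)
```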