MSDN_base.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:MSDN 作者: yikang-li 项目源码 文件源码
def build_loss_cls(self, cls_score, labels):
    """Compute the weighted predicate-classification loss and accuracy counts.

    Class weights are the element-wise square root of
    ``self.predicate_loss_weight``, except for class 0 (background). The
    class-0 weight is replaced by the foreground/background ratio of the
    current batch, so the background term is rebalanced on every call.

    Args:
        cls_score: Unnormalized class scores, one row per sample (assumed
            shape ``(N, C)`` -- TODO confirm against caller).
        labels: Integer target labels holding ``N`` elements in total,
            e.g. shape ``(N,)`` or ``(N, 1)``. Label 0 is background;
            every other label is foreground.

    Returns:
        Tuple ``(cross_entropy, tp, tf, fg_cnt, bg_cnt)``:

        - ``cross_entropy`` is the weighted ``F.cross_entropy`` loss.
        - ``tp`` is the number of foreground samples whose arg-max
          prediction equals their label.
        - ``tf`` is the number of background samples predicted as class 0.
        - ``fg_cnt`` and ``bg_cnt`` are the foreground and background
          sample counts.

        All four counts are Python ints.
    """
    # Flatten instead of squeeze() so a single-sample batch keeps shape (1,)
    # rather than collapsing to a 0-dim tensor that cross_entropy rejects.
    labels = labels.contiguous().view(-1)
    label_data = labels.data
    fg_mask = label_data.ne(0)
    bg_mask = label_data.eq(0)
    fg_cnt = int(torch.sum(fg_mask))
    bg_cnt = label_data.numel() - fg_cnt

    # torch.sqrt returns a new tensor, so the in-place write below never
    # modifies self.predicate_loss_weight.
    ce_weights = torch.sqrt(self.predicate_loss_weight)
    # The epsilon avoids division by zero when the batch has no background.
    ce_weights[0] = float(fg_cnt) / (bg_cnt + 1e-5)
    # Follow cls_score's tensor type (dtype and CPU/CUDA) instead of forcing
    # .cuda(), so the loss also works on CPU and with matching dtypes.
    ce_weights = ce_weights.type_as(cls_score.data)
    cross_entropy = F.cross_entropy(cls_score, labels, weight=ce_weights)

    _, predict = cls_score.data.max(1)
    correct = predict.eq(label_data)

    # Select by label mask rather than by slice position, so the counts do
    # not depend on background samples being ordered before foreground ones.
    # Empty selections are short-circuited to 0.
    tp = int(torch.sum(correct[fg_mask])) if fg_cnt > 0 else 0
    tf = int(torch.sum(correct[bg_mask])) if bg_cnt > 0 else 0

    return cross_entropy, tp, tf, fg_cnt, bg_cnt
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号