losses.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:jaccardSegment 作者: bermanmaxim 项目源码 文件源码
def crossentropyloss(logits, label):
    mask = (label.view(-1) != VOID_LABEL)
    nonvoid = mask.long().sum()
    if nonvoid == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    # if nonvoid == mask.numel():
    #     # no void pixel, use builtin
    #     return F.cross_entropy(logits, Variable(label))
    target = label.view(-1)[mask]
    C = logits.size(1)
    logits = logits.permute(0, 2, 3, 1)  # B, H, W, C
    logits = logits.contiguous().view(-1, C)
    mask2d = mask.unsqueeze(1).expand(mask.size(0), C).contiguous().view(-1)
    logits = logits[mask2d].view(-1, C)
    loss = F.cross_entropy(logits, Variable(target))
    return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号