loss.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch-semseg 作者: meetshah1995 项目源码 文件源码
def cross_entropy2d(input, target, weight=None, size_average=True):
    n, c, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]
    log_p = log_p.view(-1, c)

    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
    if size_average:
        loss /= mask.data.sum()
    return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号