def cross_entropy2d(pred, target, weight=None, size_average=True):
    """2D cross-entropy loss for dense prediction (e.g. semantic segmentation).

    Args:
        pred: logits of shape (n, num_classes, h, w).
        target: integer class labels of shape (n, h, w) — assumed; confirm
            against callers.
        weight: optional per-class rescaling weights forwarded to ``nll_loss``.
        size_average: if True, divide the summed loss by the number of
            pixels (n * h * w); otherwise return the per-batch sum.

    Returns:
        Scalar loss tensor.
    """
    n, num_classes, h, w = pred.size()
    # dim=1 is the class axis. The original omitted dim=, which is deprecated
    # and resolved by a rank-based heuristic instead of an explicit choice.
    log_p = F.log_softmax(pred, dim=1)
    # (n, c, h, w) -> (n, h, w, c) -> (n*h*w, c): one row of class scores
    # per pixel.
    log_p = log_p.permute(0, 2, 3, 1).contiguous().view(-1, num_classes)
    # target has no channel axis, so a plain flatten already matches the
    # pixel order of log_p above. The original ran the channel-permute
    # helper on the 3-D label tensor as well, which would scramble its
    # spatial axes and misalign labels with predictions.
    target = target.contiguous().view(-1)
    # Sum the per-pixel losses, then normalize explicitly below so the
    # divisor (h * w * n) is well defined and under our control.
    loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
    if size_average:
        loss /= (h * w * n)
    return loss
# Blog-scrape page residue, not code: 评论列表 (comment list) / 文章目录 (article table of contents).