def binaryXloss(logits, label):
    """Binary cross-entropy over every non-void position of ``logits``.

    Both inputs are flattened. Positions whose label equals ``VOID_LABEL``
    are dropped, and the remaining logits and (float) targets are passed to
    ``StableBCELoss``.

    Args:
        logits: Tensor of scores. It must contain the same number of
            elements as ``label``, because the mask built from ``label`` is
            applied to the flattened ``logits``. (Presumably raw,
            unnormalized logits, as the name ``StableBCELoss`` suggests;
            confirm against that class.)
        label: Tensor of per-position targets. Entries equal to
            ``VOID_LABEL`` are ignored; the others are presumably 0/1.
            May be non-contiguous.

    Returns:
        The value produced by ``StableBCELoss`` on the non-void positions.
        If every position is void, returns ``logits.sum() * 0.`` instead.
        That is a zero still attached to ``logits``'s graph, so backward
        yields zero gradients rather than running the loss on an empty
        selection.
    """
    # Flatten the label once. ``.contiguous()`` must come before ``.view``
    # so that non-contiguous labels (e.g. after a permute) do not raise.
    # The original only did this on its second flatten, after the first
    # ``view`` had already failed for such inputs.
    flat_label = label.contiguous().view(-1)
    mask = flat_label != VOID_LABEL
    if mask.long().sum() == 0:
        # Only void pixels: the gradients should be 0.
        return logits.sum() * 0.
    target = flat_label[mask]
    logits = logits.contiguous().view(-1)[mask]
    # ``Variable`` is kept for compatibility with pre-0.4 PyTorch. Since
    # 0.4, ``Variable(t)`` simply returns a tensor.
    loss = StableBCELoss()(logits, Variable(target.float()))
    return loss
评论列表
文章目录