def backward(self, grad_output):
input, _ = self.saved_tensors
intersect, union = self.intersect, self.union
target = self.target_
gt = torch.div(target, union)
IoU2 = intersect/(union*union)
pred = torch.mul(input[:, 1], IoU2)
dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
grad_input = torch.cat((torch.mul(dDice, -grad_output[0]),
torch.mul(dDice, grad_output[0])), 0)
return grad_input , None
评论列表
文章目录