loss.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:torchbiomed 作者: mattmacy 项目源码 文件源码
def backward(self, grad_output):
        input, _ = self.saved_tensors
        intersect, union = self.intersect, self.union
        target = self.target_
        gt = torch.div(target, union)
        IoU2 = intersect/(union*union)
        pred = torch.mul(input[:, 1], IoU2)
        dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
        grad_input = torch.cat((torch.mul(dDice, -grad_output[0]),
                                torch.mul(dDice, grad_output[0])), 0)
        return grad_input , None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号