MarginRankingCriterion.py source code


Project: pytorch-dist  Author: apaszke
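def updateGradInput(self, input, y):
    # Backward pass of the margin ranking loss
    #   loss = max(0, -y * (x1 - x2) + margin),
    # where input = (x1, x2) and y holds the +1/-1 ranking targets.
    if input[0].size(0) == 1:
        # Single-sample case: compute the hinge argument directly.
        dist = -y * (input[0][0] - input[1][0]) + self.margin
        if dist < 0:
            # Margin satisfied: the hinge is inactive, so the gradient is zero.
            self.gradInput[0][0] = 0
            self.gradInput[1][0] = 0
        else:
            # d(loss)/d(x1) = -y and d(loss)/d(x2) = +y.
            self.gradInput[0][0] = -y
            self.gradInput[1][0] = y
    else:
        # Batched case. self.dist and self.mask are assumed to start as None
        # (as the original `self.dist or ...` idiom implies); test for None
        # explicitly, because the truth value of a multi-element tensor is
        # ambiguous and raises an error.
        if self.dist is None:
            self.dist = input[0].new()
        dist = self.dist.resize_as_(input[0]).copy_(input[0])

        # dist = -y * (x1 - x2) + margin, computed in place.
        dist.sub_(input[1])
        dist.mul_(-1).mul_(y)
        dist.add_(self.margin)

        # mask = 1 where the hinge is active (dist >= 0), 0 elsewhere.
        if self.mask is None:
            self.mask = input[0].new()
        mask = self.mask.resize_as_(dist).copy_(dist.ge(0))

        self.gradInput[0].resize_(dist.size())
        self.gradInput[1].resize_(dist.size())

        # d(loss)/d(x1) = -y * mask and d(loss)/d(x2) = +y * mask.
        self.gradInput[0].copy_(mask).mul_(-1).mul_(y)
        self.gradInput[1].copy_(mask).mul_(y)

        # Average over the batch when sizeAverage is set.
        if self.sizeAverage:
            self.gradInput[0].div_(y.size(0))
            self.gradInput[1].div_(y.size(0))

    return self.gradInput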
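For reference, the batched branch computes the gradient of the hinge loss max(0, -y * (x1 - x2) + margin): -y with respect to x1 and +y with respect to x2 wherever the hinge is active, zero elsewhere, divided by the batch size when sizeAverage is set. The sketch below re-derives that gradient with plain tensor operations and checks it against autograd applied to `torch.nn.functional.margin_ranking_loss`. `margin_ranking_grad` is a hypothetical helper written only for this illustration (it is not part of PyTorch or pytorch-dist), and the sketch assumes 1-D inputs and a recent PyTorch release.

```python
import torch
import torch.nn.functional as F


def margin_ranking_grad(x1, x2, y, margin=0.0, size_average=True):
    """Hypothetical helper: gradient of max(0, -y * (x1 - x2) + margin).

    Mirrors the batched branch of updateGradInput, treating the hinge as
    active wherever dist >= 0.
    """
    dist = -y * (x1 - x2) + margin
    mask = (dist >= 0).to(x1.dtype)  # 1.0 where the hinge is active
    grad_x1 = -y * mask              # d(loss)/d(x1)
    grad_x2 = y * mask               # d(loss)/d(x2)
    if size_average:
        # Mean reduction: each element contributes 1 / batch_size.
        grad_x1 = grad_x1 / y.size(0)
        grad_x2 = grad_x2 / y.size(0)
    return grad_x1, grad_x2


# Compare the hand-written gradient with autograd on random data.
torch.manual_seed(0)
x1 = torch.randn(8, dtype=torch.float64, requires_grad=True)
x2 = torch.randn(8, dtype=torch.float64, requires_grad=True)
y = torch.randint(0, 2, (8,), dtype=torch.float64) * 2 - 1  # targets in {-1, +1}
margin = 0.5

loss = F.margin_ranking_loss(x1, x2, y, margin=margin, reduction="mean")
loss.backward()

g1, g2 = margin_ranking_grad(x1.detach(), x2.detach(), y, margin=margin)
print(torch.allclose(g1, x1.grad), torch.allclose(g2, x2.grad))
```

Treating dist == 0 as active mirrors the `>=` test in the legacy code; with random floating-point inputs that boundary case practically never occurs, so the comparison with autograd does not depend on it.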