MarginRankingCriterion.py source code

Project: pytorch-coriander · Author: hughperkins
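`updateGradInput` backpropagates the margin ranking loss. For one pair of scores $x_1, x_2$ with label $y \in \{1, -1\}$ and margin $m$ (`self.margin`), write $d = -y\,(x_1 - x_2) + m$ for the hinge argument (`dist` in the code). The loss and its (sub)gradients are:

$$
\ell(x_1, x_2, y) = \max\bigl(0,\; d\bigr), \qquad d = -y\,(x_1 - x_2) + m
$$

$$
\frac{\partial \ell}{\partial x_1} =
\begin{cases} -y & d \ge 0 \\ 0 & d < 0 \end{cases},
\qquad
\frac{\partial \ell}{\partial x_2} =
\begin{cases} y & d \ge 0 \\ 0 & d < 0 \end{cases}
$$

The hinge is not differentiable at $d = 0$; the code takes the subgradient from the active side there (the single-pair branch zeroes the gradient only when `dist < 0`, and the batched branch builds its mask from `dist >= 0`). When `sizeAverage` is set, both gradients are additionally divided by the number of pairs $N$.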
```python
# MarginRankingCriterion.updateGradInput (legacy nn).
# Per-pair loss: max(0, -y * (x1 - x2) + margin), with y = 1 or -1.
def updateGradInput(self, input, y):
    x1, x2 = input[0], input[1]
    if x1.size(0) == 1:
        # Single pair: evaluate the hinge argument directly.
        dist = -y * (x1[0] - x2[0]) + self.margin
        if dist < 0:
            # Margin already satisfied: the hinge is flat, so the gradient is zero.
            self.gradInput[0][0] = 0
            self.gradInput[1][0] = 0
        else:
            self.gradInput[0][0] = -y
            self.gradInput[1][0] = y
    else:
        # Batched case: reuse a cached buffer for dist = -y * (x1 - x2) + margin.
        if self.dist is None:
            self.dist = x1.new()
        dist = self.dist.resize_as_(x1).copy_(x1)
        dist.sub_(x2)          # x1 - x2
        dist.mul_(-1).mul_(y)  # -y * (x1 - x2)
        dist.add_(self.margin)

        # mask[i] = 1 where the hinge is active (dist >= 0), else 0.
        if self.mask is None:
            self.mask = x1.new()
        mask = self.mask.resize_as_(x1)
        mask.copy_(dist.ge(0))

        self.gradInput[0].resize_(dist.size())
        self.gradInput[1].resize_(dist.size())

        # d(loss)/dx1 = -y * mask, d(loss)/dx2 = y * mask
        self.gradInput[0].copy_(mask).mul_(-1).mul_(y)
        self.gradInput[1].copy_(mask).mul_(y)

        # Average over the N pairs when sizeAverage is set.
        if self.sizeAverage:
            self.gradInput[0].div_(y.size(0))
            self.gradInput[1].div_(y.size(0))

    return self.gradInput
```
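To sanity-check these gradient formulas outside of Torch, here is a minimal pure-Python sketch (plain lists, no buffer reuse) that computes the same loss and the same mask-based gradients as the batched branch, plus a central finite-difference check. The names `margin_ranking_loss`, `margin_ranking_grad` and `numeric_grads` are hypothetical helpers for illustration, not part of pytorch-coriander; `size_average` plays the role of `self.sizeAverage`.

```python
from typing import List, Tuple


def margin_ranking_loss(x1: List[float], x2: List[float], y: List[float],
                        margin: float = 0.0, size_average: bool = True) -> float:
    """Sum (or mean) of max(0, -y * (x1 - x2) + margin) over all pairs."""
    total = sum(max(0.0, -yi * (a - b) + margin) for a, b, yi in zip(x1, x2, y))
    return total / len(y) if size_average else total


def margin_ranking_grad(x1: List[float], x2: List[float], y: List[float],
                        margin: float = 0.0, size_average: bool = True
                        ) -> Tuple[List[float], List[float]]:
    """Analytic gradients: mask = (dist >= 0), dx1 = -y * mask, dx2 = y * mask."""
    scale = 1.0 / len(y) if size_average else 1.0
    g1, g2 = [], []
    for a, b, yi in zip(x1, x2, y):
        dist = -yi * (a - b) + margin
        mask = 1.0 if dist >= 0 else 0.0
        g1.append(-yi * mask * scale)
        g2.append(yi * mask * scale)
    return g1, g2


def numeric_grads(x1: List[float], x2: List[float], y: List[float],
                  margin: float = 0.0, size_average: bool = True,
                  eps: float = 1e-6) -> Tuple[List[float], List[float]]:
    """Central finite differences of the loss w.r.t. every x1[i] and x2[i]."""
    def loss_at(a: List[float], b: List[float]) -> float:
        return margin_ranking_loss(a, b, y, margin, size_average)

    g1, g2 = [], []
    for i in range(len(y)):
        up, down = list(x1), list(x1)
        up[i] += eps
        down[i] -= eps
        g1.append((loss_at(up, x2) - loss_at(down, x2)) / (2 * eps))
        up, down = list(x2), list(x2)
        up[i] += eps
        down[i] -= eps
        g2.append((loss_at(x1, up) - loss_at(x1, down)) / (2 * eps))
    return g1, g2


if __name__ == "__main__":
    # Two active pairs and one pair whose margin is already satisfied.
    x1, x2, y = [0.5, 2.0, 2.0], [1.0, 1.5, 0.3], [1.0, -1.0, 1.0]
    print(margin_ranking_grad(x1, x2, y, margin=0.1))
    print(numeric_grads(x1, x2, y, margin=0.1))
```

The two printed gradient pairs should agree to roughly six decimal places as long as no pair sits exactly on the hinge ($d = 0$), where finite differences and the chosen subgradient can legitimately differ.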