CosineEmbeddingCriterion.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def updateOutput(self, input, y):
        """Compute the cosine embedding loss for a batch of input pairs.

        For each row i, the cosine similarity between input1[i] and input2[i]
        is computed. It is then turned into a per-sample loss according to y[i]:

          * y[i] == 1:  loss = 1 - cos
          * y[i] == -1: loss = max(0, cos - self.margin)

        Rows whose label is neither 1 nor -1 are selected by neither mask
        and therefore keep the raw cosine value.

        Args:
            input: pair ``(input1, input2)`` of tensors. Assumed to be (N, D)
                with features along dim 1 (reductions use dim 1) -- TODO confirm.
            y: label tensor with one entry per row, expected to hold 1 or -1.

        Returns:
            The summed per-sample loss, divided by ``y.size(0)`` when
            ``self.sizeAverage`` is truthy. The value is also stored in
            ``self.output``.
        """
        input1, input2 = input[0], input[1]

        # Lazily allocate work buffers. An explicit ``is None`` test is
        # required: after the first call ``self.buffer`` is a non-empty tensor,
        # and using a multi-element tensor as a bool raises. ``getattr`` also
        # covers instances that never had the attribute set (presumably the
        # "backward compatibility" case -- confirm against __init__).
        if getattr(self, 'buffer', None) is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self._outputs = input1.new()

            # The comparison mask must live on the same backend as y, which
            # presumably matches the inputs. Match every CUDA tensor type,
            # not only torch.cuda.FloatTensor.
            if input1.type().startswith('torch.cuda.'):
                self._idx = torch.cuda.ByteTensor()
            else:
                self._idx = torch.ByteTensor()

        # w1 = row-wise dot product of input1 and input2 (legacy API: out first).
        torch.mul(self.buffer, input1, input2)
        torch.sum(self.w1, self.buffer, 1)

        # epsilon keeps the reciprocals finite for all-zero rows.
        epsilon = 1e-12
        # w22 = 1 / (|input1|^2 + eps)
        torch.mul(self.buffer, input1, input1)
        torch.sum(self.w22, self.buffer, 1).add_(epsilon)
        # self._outputs is also used as a temporary buffer: a tensor of ones
        # that serves as the numerator for both reciprocals.
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self.w22, self._outputs, self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        # w32 = 1 / (|input2|^2 + eps); then w = sqrt(w22 * w32),
        # i.e. roughly 1 / (|input1| * |input2|).
        torch.mul(self.buffer, input2, input2)
        torch.sum(self.w32, self.buffer, 1).add_(epsilon)
        torch.div(self.w32, self._outputs, self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # Per-row cosine similarity. The reduced dim 1 is dropped by
        # select(1, 0), leaving one value per row.
        torch.mul(self._outputs, self.w1, self.w)
        self._outputs = self._outputs.select(1, 0)

        # y == -1: hinge on (cos - margin), clamped below at 0.
        torch.eq(self._idx, y, -1)
        self._outputs[self._idx] = self._outputs[self._idx].add_(-self.margin).cmax_(0)
        # y == 1: loss is 1 - cos.
        torch.eq(self._idx, y, 1)
        self._outputs[self._idx] = self._outputs[self._idx].mul_(-1).add_(1)

        self.output = self._outputs.sum()

        if self.sizeAverage:
            self.output = self.output / y.size(0)

        return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号