CosineDistance.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def updateGradInput(self, input, gradOutput):
        v1  = input[0]
        v2  = input[1]
        v1, v2 = self._makeContiguous(v1, v2)

        if len(self.gradInput) != 2:
           self.gradInput[0] = self.gradInput[0] or v1.new()
           self.gradInput[1] = self.gradInput[1] or v1.new()
           self.gradInput = self.gradInput[:2]

        gw1 = self.gradInput[0]
        gw2 = self.gradInput[1]
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        torch.mul(self.buffer, self.w1, self.w22)
        gw1.addcmul_(-1, self.buffer.expand_as(v1), v1)
        gw1.mul_(self.w.expand_as(v1))

        torch.mul(self.buffer, self.w1, self.w32)
        gw2.addcmul_(-1, self.buffer.expand_as(v1), v2)
        gw2.mul_(self.w.expand_as(v1))

        go = gradOutput.view(-1, 1).expand_as(v1)
        gw1.mul_(go)
        gw2.mul_(go)

        return self.gradInput
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号