WeightedEuclidean.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def accGradParameters(self, input, gradOutput, scale=1):
        """Accumulate the gradients of ``weight`` and ``diagCov``.

        The derivation noted by the original author, with
        ``y_j = || c_j * (w_j - x) ||``::

            dy_j   2 * c_j * c_j * (w_j - x)    c_j * c_j * (w_j - x)
            ---- = -------------------------- = ---------------------
            dw_j    2 || c_j * (w_j - x) ||             y_j

            dy_j    2 * c_j * (w_j - x)^2    c_j * (w_j - x)^2
            ---- = ----------------------- = -----------------
            dc_j   2 || c_j * (w_j - x) ||         y_j

        This method assumes a preceding call to ``updateGradInput``: it reads
        the scratch buffers ``_repeat``, ``_repeat2``, ``_expand3`` /
        ``_repeat3`` and ``_expand4`` without filling them itself.
        NOTE(review): the exact contents of those buffers are produced
        outside this block -- confirm against updateOutput/updateGradInput.

        ``self._repeat`` is overwritten in place, so calling this twice
        without refreshing the buffers accumulates a different value into
        ``gradDiagCov`` the second time.

        Args:
            input: 1-D or 2-D tensor; only its dimensionality and type are
                inspected here.
            gradOutput: not read by this method (kept for the module
                interface).
            scale: multiplier applied to everything that is accumulated
                into ``gradWeight`` and ``gradDiagCov``.

        Raises:
            RuntimeError: if ``input`` is neither 1-D nor 2-D.
        """
        inputSize, outputSize = self.weight.size(0), self.weight.size(1)

        ndim = input.dim()
        if ndim not in (1, 2):
            raise RuntimeError("1D or 2D input expected")

        # Same check in both branches (the 1-D branch used to call the
        # non-existent ``torch.type``).
        is_cuda_float = input.type() == 'torch.cuda.FloatTensor'

        if ndim == 1:
            self.gradWeight.add_(self._repeat2, alpha=-scale)

            # _repeat <- (_repeat / diagCov)^2 * diagCov, all in place.
            self._repeat.div_(self.diagCov)
            self._repeat.mul_(self._repeat)
            self._repeat.mul_(self.diagCov)

            if is_cuda_float:
                # Materialise _expand4 into _repeat2 before multiplying.
                self._repeat2.resize_as_(self._expand4).copy_(self._expand4)
                self._repeat2.mul_(self._repeat)
            else:
                torch.mul(self._repeat, self._expand4, out=self._repeat2)

            # Apply ``scale`` here as well, matching the 2-D branch and the
            # gradWeight update (identical result for the default scale=1).
            self.gradDiagCov.add_(self._repeat2, alpha=scale)
            return

        # ---- 2-D input: reduce over the first dimension ----
        # Compare against None explicitly: truth-testing a tensor is either
        # an error (several elements) or depends on its value.
        if self._sum is None:
            self._sum = input.new()
        torch.sum(self._repeat2, 0, out=self._sum)
        # Harmless when the reduced dimension was already dropped; collapses
        # a kept leading singleton dimension otherwise.
        self._sum.resize_(inputSize, outputSize)
        self.gradWeight.add_(self._sum, alpha=-scale)

        if is_cuda_float:
            # requires lots of memory, but minimizes cudaMallocs and loops
            self._repeat.div_(self._repeat3)
            self._repeat.mul_(self._repeat)
            self._repeat.mul_(self._repeat3)
            self._repeat2.resize_as_(self._expand4).copy_(self._expand4)
            self._repeat.mul_(self._repeat2)
        else:
            self._repeat.div_(self._expand3)
            self._repeat.mul_(self._repeat)
            self._repeat.mul_(self._expand3)
            self._repeat.mul_(self._expand4)

        torch.sum(self._repeat, 0, out=self._sum)
        self._sum.resize_(inputSize, outputSize)
        self.gradDiagCov.add_(self._sum, alpha=scale)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号