Normalize.py 文件源码

python
阅读 50 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def updateOutput(self, input):
        """Normalize every row of a 2-D ``input`` by its L-p norm.

        Reads ``self.p`` (the norm order; ``float('inf')`` selects the
        largest absolute value of each row) and ``self.eps`` (added to the
        norm term before dividing, presumably to avoid dividing by zero on
        all-zero rows).  Work tensors are cached on the instance and reused
        between calls: ``_output``, ``norm``, ``buffer``, ``normp`` (finite
        ``p`` only) and ``_indices`` (infinity norm only).  ``norm``,
        ``normp`` and ``_indices`` are kept with shape ``(rows, 1)``.

        Results are written into the cached tensors through the ``out=``
        keyword of the torch functions.

        Args:
            input: 2-D tensor; each row is normalized independently.

        Returns:
            ``self.output``, a view of ``self._output`` with the same size as
            ``input``.  For the infinity norm each row is divided by
            ``max(|row|) + eps``; otherwise by
            ``(sum(|row| ** p) + eps) ** (1 / p)``.

        Raises:
            AssertionError: if ``input`` is not 2-D.
        """
        assert input.dim() == 2
        input_size = input.size()
        rows = input_size[0]

        def _reuse_or_new(cached):
            # ``cached or input.new()`` would take the truth value of a tensor,
            # which is an error for tensors with more than one element.  Test
            # for "unset or empty" explicitly instead.
            if cached is None or cached.numel() == 0:
                return input.new()
            return cached

        self._output = _reuse_or_new(self._output)
        self.norm = _reuse_or_new(self.norm)
        self.buffer = _reuse_or_new(self.buffer)

        # Size the work tensors up front so that passing them as ``out=``
        # never depends on implicit resizing of a non-empty tensor.
        self._output.resize_as_(input)
        self.buffer.resize_as_(input)
        self.norm.resize_(rows, 1)

        # specialization for the infinity norm
        if self.p == float('inf'):
            if self._indices is None or self._indices.numel() == 0:
                # torch.max stores the positions of the maxima in a long
                # tensor; derive it from ``input`` so it lives on the same
                # device as the data.
                self._indices = input.new().long()
            self._indices.resize_(rows, 1)

            torch.abs(input, out=self.buffer)
            torch.max(self.buffer, 1, keepdim=True,
                      out=(self.norm, self._indices))
            self.norm.add_(self.eps)
        else:
            self.normp = _reuse_or_new(self.normp)
            self.normp.resize_(rows, 1)

            if self.p % 2 != 0:
                # Unless p is an even integer, the sign must be removed
                # before raising to the p-th power.
                torch.abs(input, out=self.buffer).pow_(self.p)
            else:
                torch.pow(input, self.p, out=self.buffer)

            # eps is added to the sum of p-th powers, i.e. before the root.
            torch.sum(self.buffer, 1, keepdim=True,
                      out=self.normp).add_(self.eps)
            torch.pow(self.normp, 1. / self.p, out=self.norm)

        # Broadcast the per-row norm across the columns and divide.
        torch.div(input, self.norm.view(-1, 1).expand_as(input),
                  out=self._output)

        self.output = self._output.view(input_size)
        return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号