Normalize.py 文件源码

python
阅读 52 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def updateOutput(self, input):
    """Divide every row of a 2-D ``input`` by that row's p-norm.

    For finite ``self.p`` the row norm is ``(sum_j |x_ij|**p + eps) ** (1/p)``.
    For ``self.p == inf`` it is ``max_j |x_ij| + eps``.  The work tensors
    (``_output``, ``norm``, ``buffer``, ``normp``, ``_indices``) are created
    lazily from ``input`` on first use and reused by later calls.

    Args:
        input: a tensor with exactly two dimensions; each row (reduction over
            dim 1) is normalized independently.

    Returns:
        ``self.output``, a tensor with the same size as ``input``.

    Raises:
        AssertionError: if ``input`` is not 2-D.
    """
    assert input.dim() == 2
    input_size = input.size()

    if self._output is None:
        self._output = input.new()
    if self.norm is None:
        self.norm = input.new()
    if self.buffer is None:
        self.buffer = input.new()

    self._output.resize_as_(input)

    if self.p == float('inf'):
        # Specialization for the infinity norm: the row norm is max |x|.
        # Test for None/empty explicitly. bool() on a populated tensor raises,
        # which made every call after the first one fail.
        if self._indices is None or self._indices.numel() == 0:
            # torch.max writes its argmax positions as int64. Build the tensor
            # from the current `input` so it is on the same device, instead of
            # inspecting the previous self.output.
            self._indices = input.new().long()

        torch.abs(input, out=self.buffer)
        # Documented form: torch.max(input, dim, keepdim, out=(values, indices)).
        torch.max(self.buffer, 1, keepdim=True, out=(self.norm, self._indices))
        self.norm.add_(self.eps)
    else:
        if self.normp is None:
            self.normp = input.new()
        # |x|**p: for even p, x**p is already non-negative, so abs() is skipped.
        if self.p % 2 != 0:
            torch.abs(input, out=self.buffer).pow_(self.p)
        else:
            torch.pow(input, self.p, out=self.buffer)

        # eps is added before taking the p-th root.
        torch.sum(self.buffer, 1, keepdim=True, out=self.normp).add_(self.eps)
        torch.pow(self.normp, 1. / self.p, out=self.norm)

    # Broadcast the (rows, 1) norms across each row and divide.
    torch.div(input, self.norm.view(-1, 1).expand_as(input), out=self._output)

    self.output = self._output.view(input_size)
    return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号