def updateOutput(self, input):
    """Normalize each row of a 2-D ``input`` to unit L_p norm.

    Computes ``output[i] = input[i] / norm[i]`` where, for finite ``p``,
    ``norm[i] = (sum_j |input[i, j]|**p + eps) ** (1 / p)`` and, for
    ``p == inf``, ``norm[i] = max_j |input[i, j]| + eps``.

    Scratch tensors (``_output``, ``norm``, ``buffer``, ``normp``,
    ``_indices``) are allocated lazily on first use and reused on later
    calls.

    Args:
        input: 2-D tensor; each row is normalized independently.

    Returns:
        ``self.output``, a view of ``self._output`` with ``input``'s size.

    Raises:
        AssertionError: if ``input`` is not 2-D.
    """
    assert input.dim() == 2
    input_size = input.size()

    if self._output is None:
        self._output = input.new()
    if self.norm is None:
        self.norm = input.new()
    if self.buffer is None:
        self.buffer = input.new()

    self._output.resize_as_(input)

    if self.p == float('inf'):
        # Specialization for the infinity norm: largest absolute entry per row.
        # Test with ``is None`` (like the other lazy buffers); the truth value
        # of a tensor is not a reliable "already allocated" check.
        if self._indices is None:
            # max() returns int64 indices on the input's device, so derive
            # the index buffer from ``input`` rather than ``self.output``.
            self._indices = input.new().long()
        torch.abs(input, out=self.buffer)
        torch.max(self.buffer, 1, keepdim=True, out=(self.norm, self._indices))
        self.norm.add_(self.eps)
    else:
        if self.normp is None:
            self.normp = input.new()
        if self.p % 2 != 0:
            # Odd or fractional p: take |x| first so the p-th power is
            # well defined and non-negative.
            torch.abs(input, out=self.buffer).pow_(self.p)
        else:
            # Even p: x**p is already non-negative, so abs() is skipped.
            torch.pow(input, self.p, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.normp, keepdim=True).add_(self.eps)
        torch.pow(self.normp, 1. / self.p, out=self.norm)

    # ``norm`` holds one value per row; broadcast it across the columns.
    torch.div(input, self.norm.view(-1, 1).expand_as(input), out=self._output)
    self.output = self._output.view(input_size)
    return self.output
评论列表
文章目录