def updateOutput(self, input):
    # lazily initialize work buffers on the first call
    if self._input is None:
        self._input = input.new()
    if self._weight is None:
        self._weight = self.weight.new()
    if self._expand is None:
        self._expand = self.output.new()
    if self._expand2 is None:
        self._expand2 = self.output.new()
    if self._repeat is None:
        self._repeat = self.output.new()
    if self._repeat2 is None:
        self._repeat2 = self.output.new()

    inputSize, outputSize = self.weight.size(0), self.weight.size(1)

    # y_j = || w_j - x || = || x - w_j ||
    assert input.dim() == 2
    batchSize = input.size(0)

    # view the input as (batchSize, inputSize, 1) so it can be expanded
    # against the weight along the outputSize dimension
    self._view(self._input, input, batchSize, inputSize, 1)
    self._expand = self._input.expand(batchSize, inputSize, outputSize)
    # make the expanded tensor contiguous (requires lots of memory)
    self._repeat.resize_as_(self._expand).copy_(self._expand)

    self._weight = self.weight.view(1, inputSize, outputSize)
    self._expand2 = self._weight.expand_as(self._repeat)

    if torch.typename(input) == 'torch.cuda.FloatTensor':
        # TODO: after adding new allocators this can be changed
        # requires lots of memory, but minimizes cudaMallocs and loops
        self._repeat2.resize_as_(self._expand2).copy_(self._expand2)
        self._repeat.add_(-1, self._repeat2)
    else:
        self._repeat.add_(-1, self._expand2)

    # L2 norm over the inputSize dimension yields the distances,
    # then drop the reduced dimension
    torch.norm(self._repeat, 2, 1, out=self.output)
    self.output.resize_(batchSize, outputSize)

    return self.output
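
To see what all the buffer juggling above actually computes, here is a minimal standalone sketch (not the module's code) that produces the same distances with plain broadcasting in current PyTorch. The function name `euclidean_forward` is made up for illustration; `weight` is assumed to have shape `(inputSize, outputSize)` as in the method above.

```python
import torch

def euclidean_forward(input, weight):
    """Reference computation: output[b, j] = || input[b] - weight[:, j] ||_2.

    input:  (batchSize, inputSize)
    weight: (inputSize, outputSize)
    returns (batchSize, outputSize)
    """
    # (batchSize, inputSize, 1) - (1, inputSize, outputSize)
    # broadcasts to (batchSize, inputSize, outputSize)
    diff = input.unsqueeze(2) - weight.unsqueeze(0)
    return diff.norm(p=2, dim=1)

# quick check against a brute-force per-column loop
inp = torch.randn(4, 3)
w = torch.randn(3, 5)
out = euclidean_forward(inp, w)
expected = torch.stack(
    [(inp - w[:, j]).norm(dim=1) for j in range(w.size(1))], dim=1)
assert torch.allclose(out, expected)
```

Broadcasting avoids the explicit `_repeat`/`_repeat2` copies, which the legacy code materializes only to keep the CUDA path down to a single fused kernel launch per step.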