def updateOutput(self, input):
    """Compute the Euclidean distance from each input row to each weight column.

    With ``self.weight`` of shape ``(inputSize, outputSize)`` and a 2-D
    ``input`` with ``batchSize`` rows, the result is::

        output[b, j] = || input[b] - weight[:, j] ||_2

    The computation broadcasts both operands to
    ``(batchSize, inputSize, outputSize)``, subtracts them and reduces with
    a 2-norm over dimension 1.

    Args:
        input: 2-D tensor. ``self._view`` (defined outside this block) is
            expected to expose it as a ``(batchSize, inputSize, 1)`` view in
            ``self._input`` -- verify against that helper.

    Returns:
        ``self.output``, sized ``(batchSize, outputSize)`` and filled in place.

    Raises:
        AssertionError: if ``input`` is not 2-dimensional.
    """
    # Lazily allocate the scratch buffers. The test must be ``is None``:
    # ``buf or new()`` would call bool() on an already-allocated tensor,
    # which is ambiguous (and raises) for tensors with several elements.
    for name, prototype in (
        ("_input", input),
        ("_weight", self.weight),
        ("_expand", self.output),
        ("_expand2", self.output),
        ("_repeat", self.output),
        ("_repeat2", self.output),
    ):
        if getattr(self, name) is None:
            setattr(self, name, prototype.new())

    inputSize, outputSize = self.weight.size(0), self.weight.size(1)

    # y_j = || w_j - x || = || x - w_j ||
    assert input.dim() == 2, "updateOutput expects a 2-D (batch) input"
    batchSize = input.size(0)

    self._view(self._input, input, batchSize, inputSize, 1)
    self._expand = self._input.expand(batchSize, inputSize, outputSize)
    # make the expanded tensor contiguous (requires lots of memory) so that
    # it can be modified in place without touching the caller's input
    self._repeat.resize_as_(self._expand).copy_(self._expand)

    self._weight = self.weight.view(1, inputSize, outputSize)
    self._expand2 = self._weight.expand_as(self._repeat)

    if torch.typename(input) == 'torch.cuda.FloatTensor':
        # TODO: after adding new allocators this can be changed
        # requires lots of memory, but minimizes cudaMallocs and loops
        self._repeat2.resize_as_(self._expand2).copy_(self._expand2)
        self._repeat.sub_(self._repeat2)
    else:
        self._repeat.sub_(self._expand2)

    # Size the result first so that ``out=`` already has the reduced shape,
    # then reduce over the inputSize dimension directly into self.output.
    self.output.resize_(batchSize, outputSize)
    torch.norm(self._repeat, 2, 1, out=self.output)
    return self.output
评论列表
文章目录