def updateOutput(self, input):
    """Multiply ``input`` element-wise by the learnable ``self.weight``.

    ``input`` is copied into ``self.output`` (resized to ``input``'s shape),
    the copy is scaled in place, and ``self.output`` is returned.

    Two layouts are accepted:

    * ``input`` has exactly as many elements as ``self.weight``: both are
      flattened and multiplied element-wise (a single, non-batched sample).
    * otherwise ``input.size(0)`` is treated as the batch dimension; each
      flattened sample is multiplied by the flattened weight, which is
      broadcast across the batch.

    Args:
        input: tensor to scale.

    Returns:
        ``self.output``, holding ``input * weight``.

    Raises:
        torch raises (from ``expand_as``) when, in the batched case, the
        per-sample element count does not match ``self.weight.nelement()``.
    """
    # lazy-initialize scratch buffers on first call.  _output/_weight/_expand
    # are rebound to fresh views below on every call; only _repeat is reused
    # in place as a persistent buffer (other methods may rely on these
    # attributes existing — NOTE(review): confirm).
    if self._output is None:
        self._output = input.new()
        self._weight = input.new()
        self._expand = input.new()
        self._repeat = input.new()

    # Copy input into self.output; the in-place multiplications below act on
    # views of self.output.
    self.output.resize_as_(input).copy_(input)

    if input.nelement() == self.weight.nelement():
        # Non-batched input with the same element count as the weight: plain
        # element-wise product on flattened views.  Without this branch a 1-D
        # input would be viewed as (N, 1) below and fail to expand against the
        # (1, N) weight view.
        self._output = self.output.view(-1)
        self._weight = self.weight.view(-1)
        self._output.mul_(self._weight)
        return self.output

    batchSize = input.size(0)
    # TODO: expand_as_, view_
    # Flatten each sample and broadcast the weight row across the batch.
    self._output = self.output.view(batchSize, -1)
    self._weight = self.weight.view(1, -1)
    self._expand = self._weight.expand_as(self._output)

    if torch.typename(input) == 'torch.cuda.FloatTensor':
        # Materialize the expanded weight into a real buffer before the
        # multiply; presumably a workaround for older CUDA kernels that did
        # not accept expanded (stride-0) operands — TODO confirm still needed.
        self._repeat.resize_as_(self._expand).copy_(self._expand)
        self._output.mul_(self._repeat)
    else:
        self._output.mul_(self._expand)

    return self.output
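# Comment list (web-page navigation residue, not code)
# Table of contents (web-page navigation residue, not code)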