def updateOutput(self, input):
    """Run every child module on ``input`` and concatenate their outputs.

    The result's size along ``self.dimension`` is the sum of the children's
    sizes in that dimension. In every other dimension it is the maximum over
    the children. ``self.output`` is zero-filled first, so a child whose
    output is smaller in a non-concatenated dimension leaves zeros around it
    (padding). Where the child lands is decided by ``self.windowNarrow``,
    which is defined elsewhere on the class.

    Side effects: sets ``self.outputSize`` to the computed ``torch.Size``,
    and resizes and zero-fills ``self.output`` in place.

    Args:
        input: passed unchanged to each child's ``updateOutput``.

    Returns:
        ``self.output``, holding the concatenated result.
    """
    outs = []
    for i, module in enumerate(self.modules):
        currentOutput = module.updateOutput(input)
        outs.append(currentOutput)
        if i == 0:
            size = list(currentOutput.size())
        else:
            size[self.dimension] += currentOutput.size(self.dimension)
            # Iterate over the shape being built here, NOT self.outputSize:
            # that attribute still holds the previous call's shape (it is only
            # reassigned below), so its length may not match this call's
            # outputs -- e.g. it can be empty before the first forward.
            for dim in range(len(size)):
                if dim != self.dimension:
                    # take the maximum size (shouldn't change anything for batch dim)
                    size[dim] = max(size[dim], currentOutput.size(dim))

    self.outputSize = torch.Size(size)
    self.output.resize_(self.outputSize).zero_()  # zero for padding

    # Copy each child's output into its slot, advancing along self.dimension.
    offset = 0
    for currentOutput in outs:
        outputWindow = self.windowNarrow(self.output, currentOutput, offset)
        outputWindow.copy_(currentOutput)
        offset += currentOutput.size(self.dimension)
    return self.output
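# (Stray web-page text from the source this code was copied from:
#  "评论列表" = "comment list", "文章目录" = "table of contents".)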