def updateOutput(self, input):
    """Run every child module on ``input`` and concatenate their outputs.

    Each module in ``self.modules`` is called through its own
    ``updateOutput`` with the same ``input``. The results are joined along
    ``self.dimension`` into ``self.output``, which is resized in place to the
    combined shape. That combined shape is also stored on ``self.size``.

    Args:
        input: forwarded unchanged to each child module's ``updateOutput``.

    Returns:
        ``self.output``, holding the child outputs laid end to end along
        ``self.dimension`` in module order.

    Raises:
        ValueError: if ``self.modules`` is empty, since no output shape can
            be derived from zero children.
    """
    if not self.modules:
        # Without this guard an empty module list surfaces as an opaque
        # UnboundLocalError on the shape variable below.
        raise ValueError("updateOutput requires at least one module in self.modules")

    outs = [module.updateOutput(input) for module in self.modules]

    # Combined shape: the first output's shape, with the concat dimension
    # replaced by the total extent of all outputs along that dimension.
    size = list(outs[0].size())
    size[self.dimension] = sum(out.size(self.dimension) for out in outs)
    self.size = torch.Size(size)
    self.output.resize_(self.size)

    # Copy each child output into its consecutive slice of the output buffer.
    offset = 0
    for out in outs:
        extent = out.size(self.dimension)
        self.output.narrow(self.dimension, offset, extent).copy_(out)
        offset += extent

    return self.output
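# Scraped web-page residue (not code): "评论列表" = "comment list"
# Scraped web-page residue (not code): "文章目录" = "table of contents"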