def execute(self):
    """Run every program in ``self.progs`` with dynamic batching.

    Each program is a list of nodes. At step ``s`` the ``s``-th node of
    every program long enough to have one is collected. Those nodes are
    grouped by ``cellInd``, and each group is evaluated with a single call
    to ``self.cells[cellInd]`` on its inputs concatenated along dim 0. The
    batched result is split back into one piece per node. A piece is
    stored as ``node.outData`` when ``node.prev is None``; otherwise it is
    appended to ``node.prev.inpData``, where the parent reads it at a
    later step.

    Assumes (not checked here) that every node appears after all nodes
    feeding it within the same program. Otherwise a parent would run
    before its ``inpData`` is complete. ``t`` is presumably ``torch``;
    confirm against the file's imports.

    Returns:
        The ``outData`` of the last node of each program, concatenated
        along dim 0 in ``self.progs`` order.

    Raises:
        ValueError: if ``self.progs`` is empty (raised by ``max``).
    """
    maxLen = max([len(prog) for prog in self.progs])
    for step in range(maxLen):
        # The step-th node of every program that is long enough.
        stepNodes = [prog[step] for prog in self.progs if len(prog) > step]

        # Group by cell so each distinct cell runs once per step. Dict
        # insertion order keeps group order, and hence the order in which
        # parents receive child outputs, deterministic.
        groupedNodes = {}
        for node in stepNodes:
            groupedNodes.setdefault(node.cellInd, []).append(node)

        for cellInd, group in groupedNodes.items():
            arity = group[0].arity
            cell = self.cells[cellInd]
            outData = [node.inpData[0] for node in group]
            if arity in (1, 2):
                # Batch argument k of every node in the group along dim 0.
                args = [t.cat([node.inpData[k] for node in group], 0)
                        for k in range(arity)]
                # Split by each node's own row count rather than by 1, so a
                # node carrying several rows gets all of its rows back
                # instead of zip() silently misaligning the chunks.
                sizes = [node.inpData[0].size(0) for node in group]
                outData = t.split(cell(*args), sizes, 0)
            # NOTE(review): any other arity skips the cell entirely and
            # forwards inpData[0] unchanged (same as before); confirm intended.
            for node, outDat in zip(group, outData):
                if node.prev is None:
                    node.outData = outDat
                else:
                    node.prev.inpData += [outDat]
    return t.cat([prog[-1].outData for prog in self.progs], 0)
评论列表
文章目录