def forward(self, *inputs):
    """Run one forward pass through the graph and return the outputs.

    Each positional argument is fed to the corresponding input node, in the
    order of ``self.input_nodes``. The remaining nodes (neither input nor
    output nodes) are then evaluated in topological order. Finally, for every
    output node, the payloads on its incoming edges are gathered and combined
    with ``pyu.from_iterable``. Payloads are cleared before returning, and
    also when evaluation fails part-way.

    Parameters
    ----------
    *inputs
        One value per input node, in the same order as ``self.input_nodes``.

    Returns
    -------
    The per-output-node results combined with ``pyu.from_iterable``.
    NOTE(review): presumably a single output (or a single incoming edge) is
    returned bare and several are returned as a sequence — confirm against
    ``pyu.from_iterable``.

    Raises
    ------
    AssertionError
        If the number of inputs differs from the number of input nodes.
    """
    self.assert_graph_is_valid()
    input_nodes = self.input_nodes
    output_nodes = self.output_nodes
    # Explicit check rather than an ``assert`` statement so it still runs under
    # ``python -O``. Otherwise ``zip`` below would silently truncate a
    # mismatched argument list. AssertionError is kept for backward
    # compatibility with existing callers.
    if len(inputs) != len(input_nodes):
        raise AssertionError("Was expecting {} "
                             "arguments for as many input nodes, got {}."
                             .format(len(input_nodes), len(inputs)))
    try:
        # Feed each argument into its input node.
        for value, input_node in zip(inputs, input_nodes):
            self.forward_through_node(input_node, input=value)
        # Input nodes were handled above. Output nodes are never forwarded
        # through; they only collect the payloads on their incoming edges.
        boundary_nodes = set(input_nodes) | set(output_nodes)
        # Materialize the order before evaluating any node, as before.
        interior_nodes = [node for node in topological_sort(self.graph)
                          if node not in boundary_nodes]
        for node in interior_nodes:
            self.forward_through_node(node)
        # Each output node's result is built from the payloads on its
        # incoming edges, in the order ``self.graph.in_edges`` yields them.
        outputs = []
        for output_node in output_nodes:
            outputs_from_node = [self.graph[source][target]['payload']
                                 for source, target
                                 in self.graph.in_edges(output_node)]
            outputs.append(pyu.from_iterable(outputs_from_node))
    finally:
        # Always drop payloads, including when a node raises. A failed pass
        # should neither keep its intermediate results alive nor leave stale
        # payloads behind for the next call.
        self.clear_payloads()
    return pyu.from_iterable(outputs)
评论列表
文章目录