graph.py source code

python

Project: inferno    Author: inferno-pytorch
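# Method of inferno's Graph container (graph.py). It relies on these
# module-level imports:
from networkx import topological_sort
from inferno.utils import python_utils as pyu


def forward(self, *inputs):
    # Fail early if the graph is not a valid DAG
    self.assert_graph_is_valid()
    input_nodes = self.input_nodes
    output_nodes = self.output_nodes
    # Expect exactly one positional argument per input node
    assert len(inputs) == len(input_nodes), (
        "Was expecting {} arguments for as many input nodes, got {}."
        .format(len(input_nodes), len(inputs)))
    # Push each input onto its input node (renamed to avoid shadowing input())
    for input_, input_node in zip(inputs, input_nodes):
        self.forward_through_node(input_node, input=input_)
    # Topologically sort the graph so every node runs after its predecessors
    toposorted = topological_sort(self.graph)
    # Skip input nodes (already forwarded above) and output nodes
    # (they only collect payloads from their incoming edges)
    toposorted = [name for name in toposorted
                  if name not in input_nodes and name not in output_nodes]
    # Forward through the remaining (hidden) nodes in dependency order
    for node in toposorted:
        self.forward_through_node(node)
    # Read outputs from the payloads stored on each output node's incoming edges
    outputs = []
    for output_node in output_nodes:
        outputs_from_node = [self.graph[incoming][this]['payload']
                             for incoming, this in self.graph.in_edges(output_node)]
        # A single incoming edge yields its payload directly, several yield a list
        outputs.append(pyu.from_iterable(outputs_from_node))
    # Clear edge payloads so the next forward pass starts clean
    self.clear_payloads()
    # Likewise, a single output node returns its value rather than a list
    return pyu.from_iterable(outputs)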
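To make the control flow of `forward` concrete, here is a minimal, self-contained sketch of the same pattern: push the inputs onto the input nodes, visit the remaining nodes in topological order, read each output node's incoming edge payloads, then clear the payloads. It is not inferno's implementation. `TinyGraph`, `add_node`, and `from_iterable` are hypothetical names; `forward_through_node` here is an assumed stand-in (inferno's version is not shown on this page); networkx's `topological_sort` is swapped for the standard library's `graphlib.TopologicalSorter` (Python 3.9+); and `from_iterable` assumes `pyu.from_iterable` unwraps single-element lists.

```python
from graphlib import TopologicalSorter


def from_iterable(x):
    # Assumed behaviour of pyu.from_iterable: unwrap a one-element list/tuple
    return x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x


class TinyGraph:
    """Hypothetical DAG container: each hidden node is a function of the
    payloads on its incoming edges; its result is written to its outgoing edges."""

    def __init__(self):
        self.funcs = {}     # node name -> callable (None for input/output nodes)
        self.preds = {}     # node name -> ordered list of predecessor names
        self.payloads = {}  # (src, dst) edge -> value carried on that edge
        self.input_nodes = []
        self.output_nodes = []

    def add_node(self, name, func=None, previous=(), is_input=False, is_output=False):
        self.funcs[name] = func
        self.preds[name] = list(previous)
        if is_input:
            self.input_nodes.append(name)
        if is_output:
            self.output_nodes.append(name)

    def successors(self, name):
        return [n for n, ps in self.preds.items() if name in ps]

    def forward_through_node(self, name, value=None):
        # Input nodes pass their argument through; other nodes apply their function
        if name in self.input_nodes:
            result = value
        else:
            incoming = [self.payloads[(p, name)] for p in self.preds[name]]
            result = self.funcs[name](*incoming)
        # Store the result as the payload on every outgoing edge
        for succ in self.successors(name):
            self.payloads[(name, succ)] = result

    def forward(self, *inputs):
        assert len(inputs) == len(self.input_nodes), (
            f"Was expecting {len(self.input_nodes)} arguments, got {len(inputs)}.")
        for value, node in zip(inputs, self.input_nodes):
            self.forward_through_node(node, value)
        # static_order() yields every node after all of its predecessors
        skip = set(self.input_nodes) | set(self.output_nodes)
        for node in TopologicalSorter(self.preds).static_order():
            if node not in skip:
                self.forward_through_node(node)
        # Output nodes only gather payloads from their incoming edges
        outputs = [from_iterable([self.payloads[(p, out)] for p in self.preds[out]])
                   for out in self.output_nodes]
        self.payloads.clear()  # reset for the next pass
        return from_iterable(outputs)


# Usage: x -> double, x -> inc, (double, inc) -> sum -> out
g = TinyGraph()
g.add_node("x", is_input=True)
g.add_node("double", func=lambda a: 2 * a, previous=["x"])
g.add_node("inc", func=lambda a: a + 1, previous=["x"])
g.add_node("sum", func=lambda a, b: a + b, previous=["double", "inc"])
g.add_node("out", previous=["sum"], is_output=True)
print(g.forward(3))  # 2*3 + (3+1) = 10
```

Because output nodes are excluded from the sorted pass, they never compute anything themselves; they only collect what their predecessors wrote on the incoming edges, which is why an output node with two incoming edges yields a list.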