def compile(cls, source_net, compiled_net):
    """Remove redundant nodes from the computation graph.

    A node is considered redundant when it is not in the set returned by
    ``nbunch_ancestors(compiled_net, outputs)``, where ``outputs`` is read from
    ``compiled_net.graph['outputs']``. Redundant nodes are removed from
    ``compiled_net`` in place, and the same graph object is returned.

    Parameters
    ----------
    source_net : nx.DiGraph
        The original (uncompiled) network. Not used by this compiler step;
        kept for interface compatibility with the other compiler steps.
    compiled_net : nx.DiGraph
        The network being compiled. It is mutated in place. Its graph-level
        attribute ``'outputs'`` must be present.

    Returns
    -------
    compiled_net : nx.DiGraph
        The same graph object that was passed in, with redundant nodes removed.

    """
    logger.debug("{} compiling...".format(cls.__name__))
    outputs = compiled_net.graph['outputs']
    # NOTE(review): presumably this set also contains the output nodes
    # themselves (otherwise the outputs would be pruned) -- confirm against
    # nbunch_ancestors.
    output_ancestors = nbunch_ancestors(compiled_net, outputs)

    # Materialize the list of nodes to drop *before* mutating the graph. In
    # networkx >= 2, ``nodes()`` is a live view over the node dict, so
    # removing nodes while iterating it raises
    # ``RuntimeError: dictionary changed size during iteration``.
    redundant = [node for node in compiled_net.nodes() if node not in output_ancestors]
    compiled_net.remove_nodes_from(redundant)
    return compiled_net