def _find_necessary_steps(self, outputs, inputs):
    """
    Determine which graph steps need to be run to get to the requested
    outputs from the provided inputs.

    Steps whose only purpose is to produce values that were already
    provided as inputs are pruned, as are steps that are not on a path
    from the provided inputs to the requested outputs.

    :param list outputs:
        A list of desired output names. This can also be ``None`` (or any
        empty sequence), in which case the necessary steps are all graph
        nodes that are reachable from one of the provided inputs.
    :param dict inputs:
        A dictionary mapping names to values for all provided inputs. Only
        the keys are used; names that are not nodes of the graph are
        ignored.
    :returns:
        A list of the steps that need to be run for the provided inputs and
        requested outputs, in the same order as they appear in
        ``self.steps``.
    :raises ValueError:
        If one of the requested outputs is not a node of the graph.
    """
    graph = self.graph

    if not outputs:
        # If caller requested all outputs, the necessary nodes are all
        # nodes that are reachable from one of the inputs. Ignore input
        # names that aren't in the graph.
        necessary_nodes = set()
        for input_name in inputs:
            if graph.has_node(input_name):
                necessary_nodes |= nx.descendants(graph, input_name)
    else:
        # Collect everything upstream of the requested outputs, but never
        # walk *past* a provided input: its value is already known, so the
        # steps that would compute it are unnecessary -- unless they are
        # also reachable from an output through some other path.
        #
        # This is equivalent to taking ancestors in a copy of the graph
        # where the incoming edges of every provided input are removed.
        # Simply subtracting all ancestors of all inputs (as a previous
        # version did) over-prunes: e.g. with op1: a->b, op2: b->c,
        # op3: b->d, inputs {a, c} and output d, op1 is an ancestor of c
        # and would be dropped even though op3 needs b.
        provided = set(inputs)
        necessary_nodes = set()
        for output_name in outputs:
            if not graph.has_node(output_name):
                raise ValueError("graphkit graph does not have an output "
                                 "node named %s" % output_name)
            pending = [output_name]
            while pending:
                node = pending.pop()
                if node in provided:
                    # Given value: keep the node itself (it was added as a
                    # predecessor) but do not pull in what produces it.
                    continue
                for parent in graph.predecessors(node):
                    # necessary_nodes doubles as the visited set; expansion
                    # of a node does not depend on which output reached it.
                    if parent not in necessary_nodes:
                        necessary_nodes.add(parent)
                        pending.append(parent)

    # Return the needed steps in their original (execution) order.
    return [step for step in self.steps if step in necessary_nodes]
评论列表
文章目录