def compile(cls, source_net, compiled_net):
    """Add observed nodes to the computation graph.

    Nodes of ``source_net`` are visited in topological order:

    * a node flagged ``_observable`` gets an observed copy in
      ``compiled_net`` via ``cls.make_observed_copy``;
    * a node flagged ``_uses_observed`` gets an observed copy as well
      (``args_to_tuple`` is passed to ``cls.make_observed_copy``), and that
      copy is linked to the original node with an edge carrying
      ``param='observed'``;
    * all other nodes are skipped.

    For non-stochastic nodes of the first two kinds, the incoming edges from
    ``source_net`` are reproduced on the observed copy. A parent that is
    itself observable is replaced by its observed copy.

    Parameters
    ----------
    source_net : nx.DiGraph
        Source graph whose node state dicts are read for the ``_observable``,
        ``_uses_observed`` and ``_stochastic`` flags.
    compiled_net : nx.DiGraph
        Graph being compiled; it is modified in place.

    Returns
    -------
    compiled_net : nx.DiGraph
        The same ``compiled_net`` object that was passed in.

    Raises
    ------
    ValueError
        If the observed copy of a ``_uses_observed`` node has an ancestor in
        ``compiled_net`` whose ``source_net`` state is flagged ``_stochastic``.

    """
    logger.debug("%s compiling...", cls.__name__)

    # Only used for membership tests below, so a set keeps lookups O(1).
    observable = set()
    uses_observed = []

    # Topological order ensures each parent has already been classified
    # (added to `observable` if applicable) before its children are visited,
    # which the parent-linking loop below relies on.
    for node in nx.topological_sort(source_net):
        state = source_net.node[node]
        if state.get('_observable'):
            observable.add(node)
            cls.make_observed_copy(node, compiled_net)
        elif state.get('_uses_observed'):
            uses_observed.append(node)
            obs_node = cls.make_observed_copy(node, compiled_net, args_to_tuple)
            # Make edge to the using node
            compiled_net.add_edge(obs_node, node, param='observed')
        else:
            continue

        # Copy the incoming edges, but only for non-stochastic nodes.
        if state.get('_stochastic'):
            continue
        obs_node = observed_name(node)
        for parent in source_net.predecessors(node):
            link_parent = observed_name(parent) if parent in observable else parent
            # Pass edge data as keyword attributes: networkx 2.x's add_edge
            # no longer accepts an attribute dict positionally (1.x accepts
            # both forms), and unpacking already copies the mapping.
            compiled_net.add_edge(link_parent, obs_node, **source_net[parent][node])

    # Check that there are no stochastic nodes in the ancestors
    for node in uses_observed:
        # Use the observed version to query observed ancestors in the compiled_net
        obs_node = observed_name(node)
        for ancestor_node in nx.ancestors(compiled_net, obs_node):
            # Ancestors missing from source_net (e.g. observed copies created
            # above) default to an empty state. Truthiness matches the
            # `_stochastic` test used in the edge-copy step above.
            if source_net.node.get(ancestor_node, {}).get('_stochastic'):
                raise ValueError("Observed nodes must be deterministic. Observed "
                                 "data depends on a non-deterministic node {}."
                                 .format(ancestor_node))

    return compiled_net
评论列表
文章目录