def compute_down_msg(self, inc_node_msg, node_to_span_off_belief_idx,
                     node_to_span_on_start_belief_idx,
                     node_to_span_on_end_belief_idx,
                     parent_on_down_to_sum_tree_idx,
                     parent_off_down_to_sum_tree_idx):
    """Run the downward BP message pass for this layer of the tree.

    Node beliefs are formed as the product of ``self.up_node_msg`` and
    ``inc_node_msg``. Span-off beliefs are gathered directly from them, while
    span-on beliefs are obtained from a running sum of the node beliefs along
    axis 2, as the difference between the entries gathered at each span's end
    and start indices.

    Args:
      inc_node_msg: incoming messages from parent variables.
      node_to_span_off_belief_idx: map from node marginals at this layer to
        the corresponding span-off marginals.
      node_to_span_on_start_belief_idx: map marking the start of each span
        marginal.
      node_to_span_on_end_belief_idx: map marking the end of each span
        marginal.
      parent_on_down_to_sum_tree_idx: map from the marginal of the parent-on
        variable down to the child variable.
      parent_off_down_to_sum_tree_idx: map from the marginal of the
        parent-off variable down to the child variable.

    Returns:
      A ``(span_off_marginals, out_msg)`` tuple.
      span_off_marginals: span-off belief divided (via ``su.safe_divide``) by
        the sum of the span-on and span-off beliefs.
      out_msg: sum of ``inc_node_msg`` gathered through the parent-on map and
        ``inc_node_msg`` gathered through the parent-off map.
    """

    def _gather(source, index_map):
        # Every lookup in this method passes the same trailing arguments
        # (3, 4) to padded_gather_nd, so they are fixed here once.
        return padded_gather_nd(source, index_map, 3, 4)

    beliefs = self.up_node_msg * inc_node_msg

    # Span-off beliefs are read straight out of the node beliefs.
    off_belief = _gather(beliefs, node_to_span_off_belief_idx)

    # Span-on beliefs come from two lookups into the running sum along axis 2.
    running_total = tf.cumsum(beliefs, 2)
    total_at_start = _gather(running_total, node_to_span_on_start_belief_idx)
    total_at_end = _gather(running_total, node_to_span_on_end_belief_idx)
    on_belief = total_at_end - total_at_start

    # Normalise the span-off belief by the combined on + off belief.
    span_off_marginals = su.safe_divide(off_belief, on_belief + off_belief)

    # The outgoing message adds the parent-on and parent-off contributions.
    down_msg = _gather(inc_node_msg, parent_on_down_to_sum_tree_idx)
    down_msg += _gather(inc_node_msg, parent_off_down_to_sum_tree_idx)

    return span_off_marginals, down_msg
fft_tree_constrained_inference.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录